diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 6dfde68aa3ce..9000ee43e405 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -596,6 +596,8 @@ def _launch(
         c(profiler_start, index),
         lowering_semantics,
     )
+    if lowering_semantics == LoweringSemantics.Warpgroup:
+      prof_smem = dialect.with_transforms(prof_smem, ir.ArrayAttr.get([]))
     prof = profiler.OnDeviceProfiler(
         profiler_spec, prof_smem, maybe_prof_buffer
     )
diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index 4313da4d64d6..aa3e28070291 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -1274,7 +1274,6 @@ def _mgpu_slice_smem_op_lowering_rule(
 ) -> Sequence[ir.Value]:
   del ctx
   sliced_ref = _slice_smem(op.result.type, op.offset)
-
   memref_ty = ir.MemRefType(sliced_ref.type)
   if (
       memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier")
@@ -1380,7 +1379,11 @@ def _memref_subview_op_lowering_rule(
   del ctx

   in_transforms = inference_utils.in_transforms(op)[0]
-  out_transforms = inference_utils.out_transforms(op)[0]
+  if inference_utils.is_transformable_smem_memref(op.result):
+    out_transforms = inference_utils.out_transforms(op)[0]
+  else:
+    # This can happen for e.g. memref of rank 0.
+    out_transforms = ir.ArrayAttr.get([])

   if in_transforms != out_transforms:
     raise NotImplementedError(
diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py
index 889e2bfa71dc..bee8bb389427 100644
--- a/jax/experimental/mosaic/gpu/inference_utils.py
+++ b/jax/experimental/mosaic/gpu/inference_utils.py
@@ -222,6 +222,7 @@ def is_transformable_smem_memref(v: ir.Value) -> bool:
       # barriers have no business being transformed
       and v.type.element_type != barrier_ty  # pylint: disable=attribute-error
       and utils.is_smem_ref(v)
+      and v.type.rank != 0
   )


diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index 046f544886ad..7c3dbcfb3236 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -31,6 +31,7 @@
 from jaxlib.mlir.dialects import arith
 from jaxlib.mlir.dialects import gpu
 from jaxlib.mlir.dialects import memref
+from jaxlib.mlir.dialects import scf
 import numpy as np

 from .utils import *  # noqa: F403
@@ -315,7 +316,6 @@ def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Va
             self.entries_per_wg,
         ),
     )
-    self.smem_buffer_ptr = memref_ptr(self.smem_buffer, memory_space=3)
     self.gmem_buffer = gmem_buffer
     self.is_profiling_thread = arith.cmpi(
         arith.CmpIPredicate.eq,
@@ -323,23 +323,23 @@ def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Va
         c(0, i32),
     )
     # Hopefully mem2reg will remove the allocation.
-    self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
-    memref.store(c(0, i32), self.offset, [])
+    self.offset = memref.alloca(ir.MemRefType.get((), index), [], [])
+    memref.store(c(0, index), self.offset, [])

   @contextlib.contextmanager
   def record(self, name: str):
     i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
     name_id = self.spec.intern_name(name)
     def store(modifier):
       # smem_buffer[offset] = modifier | name_id
       # smem_buffer[offset + 1] = %clock
       # offset += 2
       offset = memref.load(self.offset, [])
+      base_ref = memref_slice(self.smem_buffer, offset)
+      base_ptr = memref_ptr(base_ref, memory_space=3)
       i64 = ir.IntegerType.get_signless(64)
-      base_addr = arith.addi(
-          llvm.ptrtoint(i64, self.smem_buffer_ptr),
-          arith.extui(i64, arith.muli(offset, c(4, i32))),
-      )
+      base_addr = llvm.ptrtoint(i64, base_ptr)
       llvm.inline_asm(
           ir.Type.parse("!llvm.void"),
           [self.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
@@ -349,7 +349,7 @@ def store(modifier):
           "b,l,r",
           has_side_effects=True,
       )
-      new_offset = arith.addi(offset, c(2, i32))
+      new_offset = arith.addi(offset, c(2, index))
       memref.store(new_offset, self.offset, [])
     store(ProfilerSpec.ENTER)
     yield
@@ -379,11 +379,18 @@ def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
     with when(self.is_profiling_thread):
       memref.store(self.start, wg_gmem_buffer, [c(0, index)])
       memref.store(smid(), wg_gmem_buffer, [c(1, index)])
-      num_traces = memref.load(self.offset, [])
+      num_traces = arith.index_cast(i32, memref.load(self.offset, []))
       memref.store(num_traces, wg_gmem_buffer, [c(2, index)])
-      traces = vector.load(
-          ir.VectorType.get((self.entries_per_wg - 3,), i32),
-          self.smem_buffer,
-          [c(0, index)],
+      for_op = scf.ForOp(
+          c(0, index),
+          c(self.entries_per_wg - 3, index),
+          c(1, index),
       )
-      vector.store(traces, wg_gmem_buffer, [c(3, index)])
+      with ir.InsertionPoint(for_op.body):
+        x = memref.load(self.smem_buffer, [for_op.induction_variable])
+        memref.store(
+            x,
+            wg_gmem_buffer,
+            [arith.addi(for_op.induction_variable, c(3, index))],
+        )
+        scf.yield_([])
diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py
index bd2bb662755f..16b971cba374 100644
--- a/jax/experimental/mosaic/gpu/transform_inference.py
+++ b/jax/experimental/mosaic/gpu/transform_inference.py
@@ -21,7 +21,6 @@
 from collections.abc import Callable
 from functools import partial
 import math
-from typing import cast

 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
@@ -106,9 +105,11 @@ def _resolve_transforms(


 def _transforms_from_uses(op: ir.OpView) -> ir.ArrayAttr | None:
-  transforms = None
+  if not inference_utils.is_transformable_smem_memref(op.result):
+    return None

-  for result_use in cast(ir.OpResult, op.result).uses:
+  transforms = None
+  for result_use in ir.OpResult(op.result).uses:
     consumer = result_use.owner
     op_user = consumer.operands[result_use.operand_number]
     user_transforms = inference_utils.in_transforms_for_operand(
@@ -314,7 +315,7 @@ def _infer_memref_subview_transforms(
   in_transforms = inference_utils.value_transforms(op.source)
   transforms = _resolve_transforms(transforms, in_transforms)

-  if transforms is None:
+  if not transforms:
     return None

   # Here, we have some transforms to propagate one way or the other. For now,
@@ -407,14 +408,12 @@ def _infer_memref_transpose_transforms(
   return [ir.ArrayAttr.get(in_transforms)], [out_transforms]


-# `memref.load` is used to load barrier phases---the rule needn't do anything
-# interesting, but we need to have it in order to avoid crashing on it.
 @partial(_add_transform_inference_rule, memref.LoadOp)
 def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms:
-  if not ir.MemRefType(op.memref.type).shape:
-    # memref.load returns a scalar, so there is nothing interesting to do here.
+  in_transforms = inference_utils.value_transforms(op.memref)
+  if in_transforms is None:
     return None
-  raise NotImplementedError("Non-scalar memref.load transforms")
+  return [in_transforms], []


 @partial(_add_transform_inference_rule, memref.CastOp)
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 49585faf590d..df385079b3b0 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -4566,6 +4566,36 @@ def create_kernel():
           x[slicing].reshape(sub_shape),
       )

+  def test_profiler(self):
+    def body(ctx, input, result, scratch):
+      del scratch
+      with ctx.named_region("load"):
+        reg = vector_load(input)
+      with ctx.named_region("store"):
+        vector_store(reg, result)
+
+    dtype = jnp.bfloat16
+    shape = (128, 128)
+    jax_shape = jax.ShapeDtypeStruct(shape, dtype)
+    with tempfile.TemporaryDirectory() as tmpdir:
+      kernel = mgpu.as_gpu_kernel(
+          body,
+          grid=(1, 1, 1),
+          block=(128, 1, 1),
+          in_shape=(jax_shape),
+          out_shape=jax_shape,
+          smem_scratch_shape=[],
+          prof_spec=profiler.ProfilerSpec(1024, dump_path=tmpdir),
+          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
+      )
+      param = self.prng.uniform(-1, 1, shape).astype(dtype)
+      self.assertArraysEqual(kernel(param), param)
+      [name] = os.listdir(tmpdir)
+      with open(os.path.join(tmpdir, name)) as f:
+        data = f.read()
+      self.assertEqual(data.count('"name": "load"'), 2)
+      self.assertEqual(data.count('"name": "store"'), 2)
+

 class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):

diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 10286ff4e17c..b0b5d5a91d56 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1719,8 +1719,6 @@ def kernel(o_ref):
     np.testing.assert_array_equal(kernel(), x)

   def test_profiler(self):
-    self.skip_if_wg_semantics()  # Transform inference not implemented.
-
     def kernel(x_ref, o_ref):
       with jax.named_scope("add"):
         with jax.named_scope("load"):
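
For context, the combination this change enables (on-device profiling of a kernel lowered with warpgroup semantics) is exercised roughly as follows. This is a minimal sketch mirroring the new test_profiler in tests/mosaic/gpu_test.py: as_gpu_kernel, ProfilerSpec, ctx.named_region and LoweringSemantics.Warpgroup appear in the diff above, while the vector_load/vector_store helpers are reimplemented here from the MLIR vector dialect (the test relies on test-suite helpers of the same name); treat the details as assumptions rather than the canonical API.

import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import vector

import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import utils as mgpu_utils


def vector_load(ref):
  # Load the whole ref as a single vector value (stand-in for the
  # test-suite helper of the same name).
  ref_ty = ir.MemRefType(ref.type)
  zero = mgpu_utils.c(0, ir.IndexType.get())
  vec_ty = ir.VectorType.get(ref_ty.shape, ref_ty.element_type)
  return vector.load(vec_ty, ref, [zero] * len(ref_ty.shape))


def vector_store(value, ref):
  # Store the vector back at the origin of the ref.
  ref_ty = ir.MemRefType(ref.type)
  zero = mgpu_utils.c(0, ir.IndexType.get())
  vector.store(value, ref, [zero] * len(ref_ty.shape))


def body(ctx, inp, out, _scratch):
  # Each named region is recorded by the on-device profiler as an
  # enter/exit event pair.
  with ctx.named_region("load"):
    reg = vector_load(inp)
  with ctx.named_region("store"):
    vector_store(reg, out)


shape_dtype = jax.ShapeDtypeStruct((128, 128), jnp.bfloat16)
with tempfile.TemporaryDirectory() as tmpdir:
  kernel = mgpu.as_gpu_kernel(
      body,
      grid=(1, 1, 1),
      block=(128, 1, 1),
      in_shape=shape_dtype,
      out_shape=shape_dtype,
      smem_scratch_shape=[],
      prof_spec=profiler.ProfilerSpec(1024, dump_path=tmpdir),
      thread_semantics=mgpu.LoweringSemantics.Warpgroup,  # newly supported here
  )
  x = np.random.uniform(-1, 1, (128, 128)).astype(jnp.bfloat16)
  np.testing.assert_array_equal(kernel(x), x)
  print(os.listdir(tmpdir))  # a single trace file dumped by the profiler

Running the kernel once writes one trace file under tmpdir, and each ctx.named_region contributes its enter/exit events to it, which is consistent with the test above expecting each region name to appear twice in the dump.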