
Commit 36b12d5

bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.
Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
1 parent 473e2bf commit 36b12d5

File tree

6 files changed: +93 -2 lines changed

  jax/experimental/mosaic/gpu/__init__.py
  jax/experimental/mosaic/gpu/core.py
  jax/experimental/mosaic/gpu/dialect_lowering.py
  jaxlib/mosaic/gpu/BUILD
  jaxlib/mosaic/gpu/custom_call.cc
  tests/mosaic/gpu_test.py

jax/experimental/mosaic/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     LaunchContext as LaunchContext,
     MemRefTransform as MemRefTransform,
     TMABarrier as TMABarrier,
+    ThreadSemantics as ThreadSemantics,
     TileTransform as TileTransform,
     TransposeTransform as TransposeTransform,
     Union as Union,

jax/experimental/mosaic/gpu/core.py

Lines changed: 24 additions & 0 deletions
@@ -17,6 +17,7 @@
 import contextlib
 import ctypes
 import dataclasses
+import enum
 import functools
 import hashlib
 import math
@@ -38,6 +39,15 @@
 from jaxlib.mlir.dialects import nvvm
 import numpy as np
 
+from jax._src.lib import mosaic_gpu_dialect as dialect  # noqa: F401
+
+if dialect is not None:
+  from . import dialect_lowering
+  from . import layout_inference
+else:
+  dialect_lowering = None
+  layout_inference = None
+
 from . import profiler
 from . import utils
 
@@ -942,6 +952,13 @@ def _declare_runtime_functions():
   )
 
 
+class ThreadSemantics(enum.Enum):
+  """Semantics for the kernel's instruction stream."""
+
+  Lane = enum.auto()
+  Warpgroup = enum.auto()
+
+
 def as_gpu_kernel(
     body,
     grid: tuple[int, int, int],
@@ -953,6 +970,7 @@
     cluster: tuple[int, int, int] = (1, 1, 1),
     module_name: str = "unknown",
     kernel_name: str | None = None,
+    thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
 ):
   if isinstance(in_shape, list):
     in_shape = tuple(in_shape)
@@ -966,6 +984,12 @@
       )
   )
 
+  if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
+    # Run Python lowering passes. The remaining passes will be run in C++ in
+    # jax/jaxlib/mosaic/gpu/custom_call.cc
+    layout_inference.infer_layout(module)  # pytype: disable=attribute-error
+    dialect_lowering.lower_mgpu_dialect(module)  # pytype: disable=attribute-error
+
   expected_arg_treedef = jax.tree.structure(in_shape)
   def _check_args(*args):
     arg_treedef = jax.tree.structure(args)
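
For reference, the two Python passes gated on `ThreadSemantics.Warpgroup` above can also be driven by hand on a module that carries Mosaic GPU dialect ops. A minimal sketch, assuming a jaxlib built with the dialect; the helper name `lower_warpgroup_module` is ours, not part of this commit:

from jax._src.lib import mosaic_gpu_dialect as dialect

def lower_warpgroup_module(module):
  """Hypothetical helper mirroring the Warpgroup branch in as_gpu_kernel."""
  if dialect is None:
    raise RuntimeError("this jaxlib was built without the Mosaic GPU dialect")
  # Imported lazily, as in core.py, since these modules require the dialect.
  from jax.experimental.mosaic.gpu import dialect_lowering, layout_inference
  # First annotate every dialect op in the module with inferred layouts...
  layout_inference.infer_layout(module)
  # ...then rewrite the dialect ops into lower-level ops; the remaining
  # passes run in C++ (see jaxlib/mosaic/gpu/custom_call.cc).
  dialect_lowering.lower_mgpu_dialect(module)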

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 13 additions & 0 deletions
@@ -237,6 +237,19 @@ def _vector_store_op_lowering_rule(
   return []
 
 
+@_register_lowering(arith.AddFOp)
+def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
+
+  fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
+  fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
+
+  return [
+      _fragmented_array_to_ir(
+          fragmented_array_lhs + fragmented_array_rhs, add.result.type
+      )
+  ]
+
+
 def lower_mgpu_dialect(module: ir.Module):
   module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   module.context.load_all_available_dialects()
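
The new rule is the template for elementwise lowerings: pull each operand out of the IR as a `FragmentedArray`, combine the arrays with the overloaded Python operator, and convert the result back to IR. As a sketch of how the pattern extends, a hypothetical rule for `arith.MulFOp` (not part of this commit) would differ only in the op it registers and the operator it applies:

@_register_lowering(arith.MulFOp)  # hypothetical, for illustration only
def _arith_mulf_op_lowering_rule(mul: arith.MulFOp) -> Sequence[ir.Value]:
  # FragmentedArray overloads the arithmetic operators, so `*` emits a
  # pointwise multiply just as `+` emits the pointwise add above.
  lhs = _fragmented_array_from_ir(mul.lhs)
  rhs = _fragmented_array_from_ir(mul.rhs)
  return [_fragmented_array_to_ir(lhs * rhs, mul.result.type)]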

jaxlib/mosaic/gpu/BUILD

Lines changed: 1 addition & 0 deletions
@@ -119,6 +119,7 @@ cc_library(
         ":passes",
         ":target",
         "//jaxlib/cuda:cuda_vendor",
+        "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 3 additions & 1 deletion
@@ -83,6 +83,7 @@ limitations under the License.
 #include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/include/mlir/Transforms/Passes.h"
 #include "jaxlib/gpu/vendor.h"
+#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
 #include "jaxlib/mosaic/gpu/launch_lowering.h"
 #include "jaxlib/mosaic/gpu/passes.h"
 #include "jaxlib/mosaic/gpu/target.h"
@@ -206,7 +207,8 @@ void InitContext(mlir::MLIRContext* context) {
       mlir::math::MathDialect, mlir::memref::MemRefDialect,
       mlir::scf::SCFDialect, mlir::vector::VectorDialect,
       mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
-      mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
+      mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect,
+      mosaic_gpu::MosaicGPUDialect>();
   mlir::registerConvertNVVMToLLVMInterface(registry);
   mlir::registerConvertComplexToLLVMInterface(registry);
   mlir::registerConvertMemRefToLLVMInterface(registry);

tests/mosaic/gpu_test.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src.lib.mlir.dialects import arith
3232
from jax._src.lib.mlir.dialects import scf
3333
from jax._src.lib.mlir.dialects import vector
34+
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
3435
from jax.experimental.mosaic.gpu import fragmented_array as fa
3536
import jax.numpy as jnp
3637
import numpy as np
@@ -165,8 +166,11 @@ def setUp(self):
165166
self.skipTest("Only works on GPU with capability >= sm90")
166167
super().setUp()
167168
self.prng = np.random.default_rng(1234)
169+
self.context = mlir.make_ir_context()
170+
if mgpu_dialect is not None:
171+
mgpu_dialect.register_dialect(self.context)
168172
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
169-
self.enter_context(mlir.make_ir_context())
173+
self.enter_context(self.context)
170174
self.enter_context(ir.Location.unknown())
171175

172176

@@ -1854,5 +1858,51 @@ def get_reg(addr):
18541858
self.assertLessEqual(len(used_regs), expected_regs)
18551859

18561860

1861+
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
1862+
"""Device tests with lowering from the MLIR dialect and layout inference."""
1863+
1864+
def setUp(self):
1865+
if mgpu_dialect is None:
1866+
raise self.skipTest("Test requires Mosaic GPU dialect")
1867+
super().setUp()
1868+
1869+
def test_pointwise_kernel(self):
1870+
def add(ctx, a, b, result, smem):
1871+
del ctx, smem
1872+
shape = ir.MemRefType(a.type).shape
1873+
elt_type = ir.MemRefType(a.type).element_type
1874+
1875+
zero_index = arith.constant(ir.IndexType.get(), 0)
1876+
1877+
# GMEM -> registers
1878+
ab_type = ir.VectorType.get(shape, elt_type)
1879+
a = vector.load(ab_type, a, [zero_index, zero_index])
1880+
b = vector.load(ab_type, b, [zero_index, zero_index])
1881+
1882+
# Computation
1883+
add = arith.addf(a, b)
1884+
1885+
# Registers -> GMEM
1886+
vector.store(add, result, [zero_index, zero_index])
1887+
1888+
dtype = jnp.bfloat16
1889+
shape = (128, 128)
1890+
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
1891+
kernel = mgpu.as_gpu_kernel(
1892+
add,
1893+
grid=(1, 1, 1),
1894+
block=(128, 1, 1),
1895+
in_shape=(jax_shape, jax_shape),
1896+
out_shape=jax_shape,
1897+
smem_scratch_shape=[],
1898+
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
1899+
)
1900+
1901+
x = self.prng.uniform(-1, 1, shape).astype(dtype)
1902+
y = self.prng.uniform(-1, 1, shape).astype(dtype)
1903+
1904+
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
1905+
1906+
18571907
if __name__ == "__main__":
18581908
absltest.main(testLoader=jtu.JaxTestLoader())
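
Assuming an sm90+ GPU and a jaxlib that ships the Mosaic GPU dialect, the new test can be run on its own via the standard unittest filter:

python tests/mosaic/gpu_test.py MosaicGpuDialectTest.test_pointwise_kernel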
