|
31 | 31 | from jax._src.lib.mlir.dialects import arith |
32 | 32 | from jax._src.lib.mlir.dialects import scf |
33 | 33 | from jax._src.lib.mlir.dialects import vector |
| 34 | +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member |
34 | 35 | from jax.experimental.mosaic.gpu import fragmented_array as fa |
35 | 36 | import jax.numpy as jnp |
36 | 37 | import numpy as np |
@@ -165,8 +166,11 @@ def setUp(self): |
165 | 166 | self.skipTest("Only works on GPU with capability >= sm90") |
166 | 167 | super().setUp() |
167 | 168 | self.prng = np.random.default_rng(1234) |
| 169 | + self.context = mlir.make_ir_context() |
| 170 | + if mgpu_dialect is not None: |
| 171 | + mgpu_dialect.register_dialect(self.context) |
168 | 172 | self.enter_context(jtu.global_config_context(jax_traceback_filtering="off")) |
169 | | - self.enter_context(mlir.make_ir_context()) |
| 173 | + self.enter_context(self.context) |
170 | 174 | self.enter_context(ir.Location.unknown()) |
171 | 175 |
|
172 | 176 |
|
@@ -1854,5 +1858,51 @@ def get_reg(addr): |
1854 | 1858 | self.assertLessEqual(len(used_regs), expected_regs) |
1855 | 1859 |
|
1856 | 1860 |
|
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
  """Device tests with lowering from the MLIR dialect and layout inference."""

  def setUp(self):
    # Bail out early on builds where the Mosaic GPU dialect is not compiled in.
    if mgpu_dialect is None:
      raise self.skipTest("Test requires Mosaic GPU dialect")
    super().setUp()

  def test_pointwise_kernel(self):
    def add_kernel(ctx, a, b, result, smem):
      """Elementwise a + b written with plain dialect ops (layouts inferred)."""
      del ctx, smem  # Unused by this kernel.
      memref_ty = ir.MemRefType(a.type)
      vec_ty = ir.VectorType.get(memref_ty.shape, memref_ty.element_type)
      zero = arith.constant(ir.IndexType.get(), 0)

      # GMEM -> registers.
      lhs = vector.load(vec_ty, a, [zero, zero])
      rhs = vector.load(vec_ty, b, [zero, zero])

      # Compute, then registers -> GMEM.
      vector.store(arith.addf(lhs, rhs), result, [zero, zero])

    dtype = jnp.bfloat16
    shape = (128, 128)
    operand_spec = jax.ShapeDtypeStruct(shape, dtype)
    kernel = mgpu.as_gpu_kernel(
        add_kernel,
        grid=(1, 1, 1),
        block=(128, 1, 1),
        in_shape=(operand_spec, operand_spec),
        out_shape=operand_spec,
        smem_scratch_shape=[],
        thread_semantics=mgpu.ThreadSemantics.Warpgroup,
    )

    x = self.prng.uniform(-1, 1, shape).astype(dtype)
    y = self.prng.uniform(-1, 1, shape).astype(dtype)

    self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
| 1906 | + |
# Entry point: run the tests via absltest with JAX's test loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
0 commit comments