Skip to content

Commit bb271aa

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Added FragmentedArray.to_layout
PiperOrigin-RevId: 686524192
1 parent 1222b4a commit bb271aa

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,15 +232,20 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
232232
match layout:
233233
case WGMMARowFragLayout():
234234
if len(shape) != 1:
235-
raise ValueError
235+
raise ValueError("WGMMARowFragLayout requires a 1D shape")
236236
if shape[0] % 64:
237-
raise ValueError
237+
raise ValueError(
238+
"WGMMARowFragLayout requires shape[0] to be a multiple of 64"
239+
)
238240
reg_shape = (shape[0] // 64, 2)
239241
case WGMMAFragLayout():
240242
if len(shape) != 2:
241-
raise ValueError
243+
raise ValueError("WGMMAFragLayout requires a 2D shape")
242244
if shape[0] % 64 or shape[1] % 8:
243-
raise ValueError
245+
raise ValueError(
246+
"WGMMAFragLayout requires shape[0] to be a multiple of 64, and"
247+
" shape[1] to be a multiple of 8"
248+
)
244249
reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
245250
value = vector.splat(ir.VectorType.get((2,), value.type), value)
246251
case WGStridedFragLayout(vec_size=vec_size):
@@ -283,6 +288,22 @@ def mlir_dtype(self):
283288
case WGMMARowFragLayout() | WGSplatFragLayout():
284289
return reg_ty
285290

291+
def to_layout(self, new_layout: FragmentedLayout):
292+
"""Converts the fragmented array to the given layout.
293+
294+
At the moment, only conversions from ``WGSplatFragLayout`` are supported.
295+
"""
296+
if self.layout == new_layout:
297+
return self
298+
if not isinstance(self.layout, WGSplatFragLayout):
299+
raise NotImplementedError(
300+
f"Cannot convert from {self.layout} to {new_layout}"
301+
)
302+
[reg] = self.registers.flat
303+
return type(self).splat(
304+
reg, self.shape, new_layout, is_signed=self.is_signed
305+
)
306+
286307
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
287308
is_signed = (
288309
output_is_signed if output_is_signed is not None else self.is_signed

tests/mosaic/gpu_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,6 +1462,18 @@ def kernel(ctx, inp, out, smem):
14621462
)(x)
14631463
np.testing.assert_array_equal(result, reference)
14641464

1465+
@parameterized.parameters(
1466+
([64 * 4], "WGMMA_ROW_LAYOUT"),
1467+
([64 * 4, 8 * 2], "WGMMA_LAYOUT"),
1468+
)
1469+
def test_to_layout(self, shape, new_layout):
1470+
def kernel(ctx, _):
1471+
# No assertions, we are just checking there are no compile-time errors.
1472+
arr = mgpu.FragmentedArray.splat(c(42.0, ir.F32Type.get()), shape)
1473+
arr.to_layout(getattr(mgpu, new_layout))
1474+
1475+
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)()
1476+
14651477

14661478
class ProfilerTest(TestCase):
14671479

0 commit comments

Comments
 (0)