Skip to content

Commit cd01971

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
[pallas:mgpu] Fix handling of AbstractRef in _commute_transform.
PiperOrigin-RevId: 873912326
1 parent f47f0c2 commit cd01971

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,7 @@ def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int:
978978
return (self.swizzle * 8) // mgpu.bitwidth(dtype)
979979

980980
def commute_transpose(
981-
self, _: jax_core.ShapedArray, transpose: state_types.TransposeTransform
981+
self, _: jax_core.AbstractValue, transpose: state_types.TransposeTransform
982982
) -> tuple[state_types.TransposeTransform, UnswizzleRef]:
983983
perm = transpose.permutation
984984
if perm[-1] != len(perm) - 1:

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1454,13 +1454,14 @@ def _commute_transform(
14541454
gpu_core.UnswizzleRef() as t1,
14551455
state_types.ReshapeTransform() as t2,
14561456
):
1457-
assert isinstance(aval, jax_core.ShapedArray)
14581457
new_reshape, new_unswizzle = t1.commute_reshape(aval, t2)
14591458
return new_reshape, new_unswizzle
14601459
case (
14611460
gpu_core.UntilingTransform() | gpu_core.UnswizzleRef() as t1,
14621461
gpu_core.TransposeTransform() as t2,
14631462
):
1463+
if isinstance(aval, state_types.AbstractRef):
1464+
aval = aval.inner_aval
14641465
assert isinstance(aval, jax_core.ShapedArray)
14651466
new_reshape, new_unswizzle = t1.commute_transpose(aval, t2)
14661467
return new_reshape, new_unswizzle

0 commit comments

Comments
 (0)