File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed
jax/_src/pallas/mosaic_gpu Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -978,7 +978,7 @@ def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int:
978978 return (self .swizzle * 8 ) // mgpu .bitwidth (dtype )
979979
980980 def commute_transpose (
981- self , _ : jax_core .ShapedArray , transpose : state_types .TransposeTransform
981+ self , _ : jax_core .AbstractValue , transpose : state_types .TransposeTransform
982982 ) -> tuple [state_types .TransposeTransform , UnswizzleRef ]:
983983 perm = transpose .permutation
984984 if perm [- 1 ] != len (perm ) - 1 :
Original file line number Diff line number Diff line change @@ -1454,13 +1454,14 @@ def _commute_transform(
14541454 gpu_core .UnswizzleRef () as t1 ,
14551455 state_types .ReshapeTransform () as t2 ,
14561456 ):
1457- assert isinstance (aval , jax_core .ShapedArray )
14581457 new_reshape , new_unswizzle = t1 .commute_reshape (aval , t2 )
14591458 return new_reshape , new_unswizzle
14601459 case (
14611460 gpu_core .UntilingTransform () | gpu_core .UnswizzleRef () as t1 ,
14621461 gpu_core .TransposeTransform () as t2 ,
14631462 ):
1463+ if isinstance (aval , state_types .AbstractRef ):
1464+ aval = aval .inner_aval
14641465 assert isinstance (aval , jax_core .ShapedArray )
14651466 new_reshape , new_unswizzle = t1 .commute_transpose (aval , t2 )
14661467 return new_reshape , new_unswizzle
You can’t perform that action at this time.
0 commit comments