@@ -364,7 +364,7 @@ class _WGMMAPipelineEffect(effects.Effect):
364364
365365def wgmma (
366366 acc : gpu_core .WGMMAAbstractAccumulatorRef ,
367- a : pallas_core . TransformedRef ,
367+ a ,
368368 b : pallas_core .TransformedRef ,
369369) -> None :
370370 """Performs an asynchronous warp group matmul-accumulate on the given references.
@@ -395,12 +395,16 @@ def wgmma(
395395 if a .dtype != b .dtype :
396396 raise ValueError (f"Mixed input dtypes for matrix multiplication unsupported: lhs={ a .dtype } , rhs={ b .dtype } " )
397397
398- a_transforms_leaves , a_transforms_tree = jax .tree .flatten (a .transforms )
398+ if isinstance (a , pallas_core .TransformedRef ):
399+ a_transforms_leaves , a_transforms_tree = jax .tree .flatten (a .transforms )
400+ a = a .ref
401+ else :
402+ a_transforms_leaves , a_transforms_tree = [], None
399403 b_transforms_leaves , b_transforms_tree = jax .tree .flatten (b .transforms )
400404
401405 wgmma_ref_p .bind (
402406 acc ,
403- a . ref ,
407+ a ,
404408 b .ref ,
405409 * a_transforms_leaves ,
406410 * b_transforms_leaves ,
@@ -411,15 +415,15 @@ def wgmma(
411415
412416@wgmma_ref_p .def_effectful_abstract_eval
413417def _wgmma_ref_effectful_abstract_eval (acc_aval , a_aval , b_aval , * _ , ** params ):
414- del a_aval , b_aval , params
418+ del b_aval , params
415419 if not isinstance (acc_aval , gpu_core .WGMMAAbstractAccumulatorRef ):
416420 raise TypeError (f"Expected WGMMAAbstractAccumulatorRef got { acc_aval } " )
417421 return (), {
418422 _wgmma_pipeline_effect ,
419423 state .WriteEffect (0 ),
420424 state .ReadEffect (0 ),
421- state .ReadEffect (1 ),
422425 state .ReadEffect (2 ),
426+ * ([state .ReadEffect (1 )] if isinstance (a_aval , state .AbstractRef ) else [])
423427 }
424428
425429
@@ -444,23 +448,31 @@ def _wgmma_lowering(
444448 b_transforms_tree ,
445449):
446450 _ , a_aval , * _ = ctx .avals_in
447- a_transforms_leaves , b_transforms_leaves = util .split_list (
448- transforms_leaves , [a_transforms_tree .num_leaves ]
449- )
450- a_transforms = a_transforms_tree .unflatten (a_transforms_leaves )
451- b_transforms = b_transforms_tree .unflatten (b_transforms_leaves )
451+ lhs_swizzle = None
452+ if a_transforms_tree is not None :
453+ a_transforms_leaves , b_transforms_leaves = util .split_list (
454+ transforms_leaves , [a_transforms_tree .num_leaves ]
455+ )
456+ a_transforms = a_transforms_tree .unflatten (a_transforms_leaves )
457+ a , a_transforms = lowering ._handle_indexing (a , a_transforms )
458+ match a_transforms :
459+ case (gpu_core .UnswizzleRef (lhs_swizzle ), gpu_core .UntileRef (tiling )):
460+ swizzle_elems = lhs_swizzle // a_aval .dtype .itemsize
461+ if tiling != (64 , swizzle_elems ):
462+ raise NotImplementedError ("WGMMA lhs tiling does not fit swizzle" )
463+ case _:
464+ raise ValueError (f"WGMMA lhs has unsupported transforms: { a_transforms } ." )
465+ else :
466+ b_transforms_leaves = transforms_leaves # type: ignore
467+ if not isinstance (a , mgpu .FragmentedArray ):
468+ raise ValueError (
469+ "When WGMMA lhs is passed in as a ref, it must be transformed by"
470+ " swizzling and tiling appropriately."
471+ )
452472
453- a , a_transforms = lowering . _handle_indexing ( a , a_transforms )
473+ b_transforms = b_transforms_tree . unflatten ( b_transforms_leaves )
454474 b , b_transforms = lowering ._handle_indexing (b , b_transforms )
455475
456- match a_transforms :
457- case (gpu_core .UnswizzleRef (swizzle ), gpu_core .UntileRef (tiling )):
458- swizzle_elems = swizzle // a_aval .dtype .itemsize
459- if tiling != (64 , swizzle_elems ):
460- raise NotImplementedError ("WGMMA lhs tiling does not fit swizzle" )
461- case _:
462- raise ValueError (f"WGMMA lhs has unsupported transforms: { a_transforms } ." )
463-
464476 match b_transforms :
465477 case (gpu_core .UnswizzleRef (rhs_swizzle ), gpu_core .UntileRef (rhs_tiling )):
466478 rhs_transpose = False
@@ -474,16 +486,18 @@ def _wgmma_lowering(
474486 case _:
475487 raise ValueError (f"WGMMA rhs has unsupported transforms: { b_transforms } ." )
476488
477- if rhs_swizzle != swizzle :
478- raise NotImplementedError ("WGMMA rhs swizzle must match lhs swizzle" )
479- if rhs_tiling != (swizzle_elems , swizzle_elems ):
480- raise NotImplementedError ("WGMMA rhs tiling does not fit swizzle" )
489+ if lhs_swizzle is not None :
490+ swizzle_elems = rhs_swizzle // a_aval .dtype .itemsize
491+ if rhs_swizzle != lhs_swizzle :
492+ raise NotImplementedError ("WGMMA rhs swizzle must match lhs swizzle" )
493+ if rhs_tiling != (swizzle_elems , swizzle_elems ):
494+ raise NotImplementedError ("WGMMA rhs tiling does not fit swizzle" )
481495
482496 new_acc = mgpu .wgmma (
483497 acc ,
484498 a ,
485499 b ,
486- swizzle = swizzle ,
500+ swizzle = rhs_swizzle ,
487501 b_order = mgpu .WGMMALayout .COL_MAJOR
488502 if rhs_transpose
489503 else mgpu .WGMMALayout .ROW_MAJOR ,
@@ -493,12 +507,12 @@ def _wgmma_lowering(
493507
494508
495509@wgmma_p .def_effectful_abstract_eval
496- def _wgmma_effectful_abstract_eval (acc , * args , ** kwargs ):
510+ def _wgmma_effectful_abstract_eval (acc , lhs_ref , * args , ** kwargs ):
497511 del args , kwargs
498512 return acc , {
499513 _wgmma_pipeline_effect ,
500- state .ReadEffect (1 ),
501514 state .ReadEffect (2 ),
515+ * ([state .ReadEffect (1 )] if isinstance (lhs_ref , state .AbstractRef ) else [])
502516 }
503517
504518wgmma_wait_p = jax_core .Primitive ("wgmma_wait" )
0 commit comments