@@ -3483,7 +3483,7 @@ def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs):
34833483 raise ValueError ("async_copy_scales_to_tmem source must be an SMEM ref" )
34843484 if tmem_ref .memory_space != gpu_core .MemorySpace .TMEM :
34853485 raise ValueError ("async_copy_scales_to_tmem target must be a TMEM ref" )
3486- return (), {gpu_core . _memory_effect }
3486+ return (), {state_types . ReadEffect ( 0 ), state_types . WriteEffect ( 1 ) }
34873487
34883488def _async_copy_to_tmem_lowering_rule (
34893489 impl , ctx : lowering .LoweringRuleContext , smem_ref , tmem_ref , * leaves , smem_tree , tmem_tree , collective_axis
@@ -3552,6 +3552,145 @@ def _async_copy_sparse_metadata_to_tmem_lowering_rule(*args, **kwargs):
35523552 )
35533553
35543554
# Primitive implementing an SMEM -> TMEM copy via the tcgen05.cp instruction.
# It produces no values (multiple_results with an empty result tuple); its
# read/write effects on the two ref operands are declared in the effectful
# abstract eval rule below.
async_copy_smem_to_tmem_p = jax_core.Primitive("async_copy_smem_to_tmem")
async_copy_smem_to_tmem_p.multiple_results = True
3558+
def async_copy_smem_to_tmem(
    smem_ref: _Ref,
    tmem_ref: _Ref,
    collective_axis: AxisName | None = None,
):
  """Issues an asynchronous SMEM -> TMEM copy (tcgen05.cp).

  The SMEM source is expected to carry tiling and swizzle transforms; the
  TMEM destination must use the packed layout (``packed=True`` for sub-32b
  element types).

  Because the copy is asynchronous, completion must be observed via
  ``tcgen05_commit_arrive`` plus a wait on the chosen barrier — unless the
  copied data is only consumed by a tcgen05_mma operation, which needs no
  extra synchronization.

  Args:
    smem_ref: The SMEM reference to copy from.
    tmem_ref: The TMEM reference to copy into.
    collective_axis: The name of the cluster axis along which the
      copy should be performed collectively. The cluster axis should have a
      size of exactly 2, and must be on the minormost cluster axis.
  """
  # Both operands go through the same peel-transforms-and-flatten dance, so
  # factor it into a local helper.
  def _peel(ref):
    base, transforms = state_primitives.get_ref_and_transforms(
        ref, None, "async_copy_smem_to_tmem"
    )
    leaves, treedef = tree_util.tree_flatten(transforms)
    return base, leaves, treedef

  smem_base, smem_leaves, smem_treedef = _peel(smem_ref)
  tmem_base, tmem_leaves, tmem_treedef = _peel(tmem_ref)
  async_copy_smem_to_tmem_p.bind(
      smem_base,
      tmem_base,
      *smem_leaves,
      *tmem_leaves,
      smem_tree=smem_treedef,
      tmem_tree=tmem_treedef,
      collective_axis=collective_axis,
  )
3599+
3600+
@async_copy_smem_to_tmem_p.def_effectful_abstract_eval
def _async_copy_smem_to_tmem_abstract_eval(
    smem_ref, tmem_ref, *args, smem_tree, **_kwargs
):
  """Validates operand memory spaces/shapes/dtypes and declares effects."""
  if smem_ref.memory_space != gpu_core.MemorySpace.SMEM:
    raise ValueError("async_copy_smem_to_tmem source must be an SMEM ref")
  if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
    raise ValueError("async_copy_smem_to_tmem target must be a TMEM ref")
  # Apply the SMEM-side transforms to the source aval so that shape/dtype are
  # compared post-transform. The leading leaves of ``args`` belong to the
  # SMEM transform tree; the rest are TMEM transform leaves.
  transforms = jax.tree.unflatten(smem_tree, args[: smem_tree.num_leaves])
  transformed = smem_ref
  for transform in transforms:
    transformed = transform.transform_type(transformed)
  if transformed.dtype != tmem_ref.dtype:
    raise ValueError(
        f"Expected SMEM element type ({transformed.dtype}) to equal the TMEM"
        f" element type ({tmem_ref.dtype})"
    )
  if transformed.shape != tmem_ref.shape:
    raise ValueError(
        f"Expected SMEM reference shape {transformed.shape} to equal the TMEM"
        f" reference shape {tmem_ref.shape}"
    )
  # No results; reads operand 0 (SMEM) and writes operand 1 (TMEM).
  return (), {state_types.ReadEffect(0), state_types.WriteEffect(1)}
3624+
3625+
@lowering.register_lowering_rule(
    async_copy_smem_to_tmem_p, mgpu.LoweringSemantics.Lane
)
@lowering.register_lowering_rule(
    async_copy_smem_to_tmem_p,
    mgpu.LoweringSemantics.Lane,
    gpu_core.PrimitiveSemantics.Warp,
)
def _async_copy_smem_to_tmem_lowering_rule(
    ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves,
    smem_tree, tmem_tree, collective_axis,
):
  # Lowers async_copy_smem_to_tmem_p to a tcgen05.async_copy_smem_to_tmem
  # call, issued from a single lane (and, for collective copies, only from
  # the leader block).
  assert isinstance(tmem_ref, tcgen05.TMEMRef)
  # ``leaves`` interleaves SMEM transform leaves first, then TMEM ones; split
  # them by the SMEM tree's leaf count and rebuild both transform pytrees.
  smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves])
  smem_transforms = jax.tree.unflatten(smem_tree, smem_leaves)
  tmem_transforms = jax.tree.unflatten(tmem_tree, tmem_leaves)
  smem_aval = ctx.avals_in[0]
  assert isinstance(smem_aval, state_types.AbstractRef)
  tmem_aval = ctx.avals_in[1]
  assert isinstance(tmem_aval, state_types.AbstractRef)
  # The remaining input avals mirror the transform leaves, in the same order.
  transform_avals = util.split_list(
      ctx.avals_in[2:], [smem_tree.num_leaves]
  )
  smem_transform_avals = smem_tree.unflatten(transform_avals[0])
  tmem_transform_avals = tmem_tree.unflatten(transform_avals[1])
  # Resolve indexing/reshaping transforms on the refs. Transposes are left
  # unhandled on the SMEM side (handle_transposes=False) so that the
  # swizzle/tiling transforms survive for the match below.
  smem_ref, transformed_smem_aval, smem_transforms = lowering._handle_transforms(
      ctx, smem_aval, smem_ref, smem_transform_avals, smem_transforms,
      handle_transposes=False
  )
  tmem_ref, _, tmem_transforms = lowering._handle_transforms(
      ctx, tmem_aval, tmem_ref, tmem_transform_avals, tmem_transforms
  )
  # Only an (unswizzle, untile) pair is supported on the SMEM source; the
  # class patterns capture the swizzle byte width and the tiling tuple.
  match smem_transforms:
    case (
        gpu_core.UnswizzleRef(swizzle),
        gpu_core.UntilingTransform(tiling),
    ):
      pass
    case _:
      raise NotImplementedError(
          f"Unsupported transforms for SMEM ref: {smem_transforms}"
      )
  # Number of elements covered by one swizzle period; the tiling's minor
  # dimension must match it exactly (and the major dimension must be 8).
  swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(transformed_smem_aval.dtype)
  if tiling != (8, swizzle_elems):
    raise ValueError(
        f"Tiling does not fit swizzle: expected (8, {swizzle_elems}), but got"
        f" {tiling}"
    )
  if tmem_transforms:
    raise NotImplementedError(
        f"Unimplemented transforms for TMEM refs: {tmem_transforms}"
    )

  # The copy instruction is issued by a single lane; for collective copies,
  # additionally restrict issuance to the leader block of the cluster pair.
  predicate = ctx.module_ctx.single_lane_predicate
  if collective_axis is not None:
    is_leader_block = _collective_mma_predicate(ctx, collective_axis)
    predicate = arith_dialect.andi(predicate, is_leader_block)
    collective = True
  else:
    collective = False

  with mgpu.when(predicate):
    tcgen05.async_copy_smem_to_tmem(
        smem_ref, tmem_ref, swizzle, collective=collective
    )
  return ()
3692+
3693+
# Primitive for parallel semaphore signaling; no results. Its abstract eval
# and lowering rules follow below (outside this view).
semaphore_signal_parallel_p = jax_core.Primitive('semaphore_signal_parallel')
semaphore_signal_parallel_p.multiple_results = True
35573696
0 commit comments