@@ -107,10 +107,10 @@ def copy_smem_to_gmem(
107107 if dst .memory_space is not gpu_core .GMEM :
108108 raise ValueError (f"dst must be a GMEM reference, got { dst .memory_space } " )
109109 src , src_transforms = state_primitives .get_ref_and_transforms (
110- src , None , "copy_smem_to_gmem"
110+ src , None , "copy_smem_to_gmem" , force_trailing_indexer = False ,
111111 )
112112 dst , dst_transforms = state_primitives .get_ref_and_transforms (
113- dst , None , "copy_smem_to_gmem"
113+ dst , None , "copy_smem_to_gmem" , force_trailing_indexer = False ,
114114 )
115115 flat_src_transforms , src_transforms_treedef = tree_util .tree_flatten (
116116 src_transforms
@@ -193,10 +193,10 @@ def copy_gmem_to_smem(
193193 if dst .memory_space is not gpu_core .SMEM :
194194 raise ValueError (f"dst must be a SMEM reference, got { dst .memory_space } " )
195195 src , src_transforms = state_primitives .get_ref_and_transforms (
196- src , None , "copy_gmem_to_smem"
196+ src , None , "copy_gmem_to_smem" , force_trailing_indexer = False ,
197197 )
198198 dst , dst_transforms = state_primitives .get_ref_and_transforms (
199- dst , None , "copy_gmem_to_smem"
199+ dst , None , "copy_gmem_to_smem" , force_trailing_indexer = False ,
200200 )
201201 flat_src_transforms , src_transforms_treedef = tree_util .tree_flatten (
202202 src_transforms
@@ -205,7 +205,7 @@ def copy_gmem_to_smem(
205205 dst_transforms
206206 )
207207 barrier , barrier_transforms = state_primitives .get_ref_and_transforms (
208- barrier , None , "copy_gmem_to_smem"
208+ barrier , None , "copy_gmem_to_smem" , force_trailing_indexer = False ,
209209 )
210210 flat_barrier_transforms , barrier_transforms_treedef = tree_util .tree_flatten (
211211 barrier_transforms
@@ -284,7 +284,7 @@ def _barrier_arrive_lowering(
284284def barrier_arrive (barrier : pallas_core .AbstractMemoryRef ) -> None :
285285 """Arrives at the given barrier."""
286286 barrier , transforms = state_primitives .get_ref_and_transforms (
287- barrier , None , "barrier_arrive"
287+ barrier , None , "barrier_arrive" , force_trailing_indexer = False ,
288288 )
289289 flat_transforms , transforms_treedef = tree_util .tree_flatten (transforms )
290290 barrier_arrive_p .bind (
@@ -321,7 +321,7 @@ def _barrier_wait_lowering(
321321def barrier_wait (barrier : pallas_core .AbstractMemoryRef ) -> None :
322322 """Waits on the given barrier."""
323323 barrier , transforms = state_primitives .get_ref_and_transforms (
324- barrier , None , "barrier_wait"
324+ barrier , None , "barrier_wait" , force_trailing_indexer = False ,
325325 )
326326 flat_transforms , transforms_treedef = tree_util .tree_flatten (transforms )
327327 barrier_wait_p .bind (
0 commit comments