@@ -46,7 +46,16 @@ class BufferedRef:
   spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
   is_index_invariant: bool = dataclasses.field(metadata={"static": True})
   gmem_ref: pallas_core.AbstractMemoryRef
-  smem_ref: pallas_core.AbstractMemoryRef  # [num_slots, *spec.block_shape]
+  # ``None`` if the ref is pinned to GMEM; otherwise, has shape
+  # [num_slots, *spec.block_shape].
+  smem_ref: pallas_core.AbstractMemoryRef | None
+
+  def get_ref_for_slot(
+      self, slot: int | jax.Array
+  ) -> pallas_core.AbstractMemoryRef:
+    if self.smem_ref is None:
+      return self.gmem_ref
+    return self.smem_ref.at[slot]

   def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     index_map = self.spec.index_map
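
The new ``get_ref_for_slot`` accessor lets the rest of the pipeline stay agnostic
about whether an operand is staged through SMEM or pinned to GMEM. A minimal
sketch of the dispatch (``bref`` and ``slot`` are hypothetical local names):

    ref = bref.get_ref_for_slot(slot)
    # Resolves to:
    #   bref.gmem_ref           if bref.smem_ref is None (GMEM-pinned operand)
    #   bref.smem_ref.at[slot]  otherwise (one SMEM buffer per concurrent step)
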
@@ -59,6 +68,9 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     )

   def copy_in(self, slot, grid_indices, barrier_ref):
+    if not _in_smem(self.spec):
+      return
+    assert self.smem_ref is not None
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_gmem_to_smem(
         self.gmem_ref.at[gmem_slices],  # pytype: disable=unsupported-operands
@@ -67,6 +79,9 @@ def copy_in(self, slot, grid_indices, barrier_ref):
     )

   def copy_out(self, slot, grid_indices, predicate=None):
+    if not _in_smem(self.spec):
+      return
+    assert self.smem_ref is not None
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_smem_to_gmem(
         self.smem_ref.at[slot],
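
With these guards, ``copy_in`` and ``copy_out`` become no-ops for refs whose spec
does not place the block in SMEM, so no async copy (and no barrier arrival) is
issued for GMEM-pinned operands. A rough sketch, assuming ``pl.BlockSpec`` accepts
a ``memory_space`` keyword and that ``gpu_core.GMEM`` names the GMEM space:

    gmem_spec = pl.BlockSpec(memory_space=gpu_core.GMEM)
    assert not _in_smem(gmem_spec)  # copy_in/copy_out return immediately
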
@@ -88,8 +103,8 @@ def _uses_arguments(
 def _is_index_invariant(
     spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
 ) -> bool:
-  index_map = spec.index_map
-  assert index_map is not None
+  if (index_map := spec.index_map) is None:
+    return True
   return not any(_uses_arguments(index_map, len(grid)))


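
A spec without an ``index_map`` cannot depend on the grid indices, so it is now
reported as index-invariant instead of tripping the old assertion. For instance,
a hypothetical GMEM-pinned spec with neither block shape nor index map:

    assert _is_index_invariant(pl.BlockSpec(memory_space=gpu_core.GMEM), (2, 3))
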
@@ -105,6 +120,10 @@ def _inc_grid_by_1(
   return tuple(reversed(next_indices))


+def _in_smem(spec: pallas_core.BlockSpec) -> bool:
+  return spec.memory_space in (None, gpu_core.SMEM)
+
+
 # ``pl.Slice`` uses a different pytree encoding, depending on whether the
 # start/size are static or dynamic. This leads to pytree structure mismatch
 # in the pipeline body. So, we define a different ``Slice`` class below.
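
``_in_smem`` treats an unspecified memory space as the default SMEM staging, so
only a spec that explicitly asks for a different space opts out of the pipeline's
buffering. Hypothetical examples (block shapes and index maps made up):

    _in_smem(pl.BlockSpec((128, 128), lambda i: (i, 0)))  # True: staged through SMEM
    _in_smem(pl.BlockSpec(memory_space=gpu_core.GMEM))    # False: pinned to GMEM
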
@@ -166,6 +185,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     if any(
         spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx]  # type: ignore
         for idx in range(1, len(grid) + 1)
+        if spec.block_shape is not None
     ):
       raise NotImplementedError(
           f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
@@ -174,14 +194,12 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):

   in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
   in_smem_refs, out_smem_refs = util.split_list(
-      map(
-          lambda spec, ref: gpu_core.SMEM(
-              (max_concurrent_steps, *spec.block_shape),  # type: ignore
-              ref.dtype,
-          ),
-          it.chain(in_specs, out_specs),
-          gmem_refs,
-      ),
+      [
+          gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype)  # type: ignore
+          if _in_smem(spec)
+          else None
+          for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
+      ],
       [len(in_specs)],
   )
   return pl.run_scoped(
@@ -194,7 +212,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
       out_smem_refs=out_smem_refs,
       barrier_ref=gpu_core.Barrier(
           # TODO(slebedev): Change this to arrive only once.
-          len(in_specs),
+          sum(map(_in_smem, in_specs)),
           num_barriers=max_concurrent_steps,
       ),
   )
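
Taken together, a GMEM-pinned operand gets no SMEM slots allocated and does not
count toward the barrier's arrival count. A hedged usage sketch, assuming the
surrounding ``emit_pipeline`` entry point keeps its existing signature (the
kernel, shapes, and ``num_steps`` below are made up):

    pipeline = emit_pipeline(
        kernel_body,
        grid=(num_steps,),
        in_specs=[
            pl.BlockSpec((128, 128), lambda i: (i, 0)),  # double-buffered in SMEM
            pl.BlockSpec(memory_space=gpu_core.GMEM),    # passed to the body as-is
        ],
        out_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
        max_concurrent_steps=2,
    )
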
@@ -233,9 +251,10 @@ def loop_body(step, carry):
       )

       with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
-        body(
-            *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
-        )
+        body(*(
+            bref.get_ref_for_slot(slot)
+            for bref in it.chain(in_brefs, out_brefs)
+        ))

       if not all(bref.is_index_invariant for bref in out_brefs):
         gpu_primitives.commit_smem()
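
Inside the loop the body now receives ``get_ref_for_slot(slot)`` for every
operand, so conceptually the call expands to something like the following
(``x_bref``, ``table_bref`` and ``o_bref`` are hypothetical BufferedRefs for a
pipeline with one GMEM-pinned input):

    body(
        x_bref.smem_ref.at[slot],    # SMEM-staged input: buffer for this slot
        table_bref.gmem_ref,         # GMEM-pinned input: the full GMEM ref
        o_bref.smem_ref.at[slot],    # SMEM-staged output
    )
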