 
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import itertools as it
 
 import jax
 from jax import lax
+from jax._src import core
+from jax._src import linear_util as lu
 from jax._src import util
+from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
 from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
 zip = util.safe_zip
 
 
+@jax.tree_util.register_dataclass
 @dataclasses.dataclass(frozen=True)
 class BufferedRef:
-  spec: pallas_core.BlockSpec
+  spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
+  is_index_invariant: bool = dataclasses.field(metadata={"static": True})
   gmem_ref: pallas_core.AbstractMemoryRef
   smem_ref: pallas_core.AbstractMemoryRef  # [num_slots, *spec.block_shape]
 
-  def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
     return tuple(
-        pl.ds(idx * size, size)
+        pl.Slice(idx * size, size)  # type: ignore[arg-type]
         for idx, size in zip(
             index_map(*grid_indices), self.spec.block_shape  # type: ignore[arg-type]
         )
@@ -61,16 +66,31 @@ def copy_in(self, slot, grid_indices, barrier_ref):
         barrier=barrier_ref.at[slot],
     )
 
-  def copy_out(self, slot, grid_indices):
+  def copy_out(self, slot, grid_indices, predicate=None):
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_smem_to_gmem(
-        self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices]  # pytype: disable=unsupported-operands
+        self.smem_ref.at[slot],
+        self.gmem_ref.at[gmem_slices],  # pytype: disable=unsupported-operands
+        predicate=predicate,
     )
 
 
-jax.tree_util.register_dataclass(
-    BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"]
-)
+def _uses_arguments(
+    index_map: Callable[..., Any], num_args: int
+) -> Sequence[bool]:
+  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
+  )
+  _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
+  return used_inputs
+
+
+def _is_index_invariant(
+    spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
+) -> bool:
+  index_map = spec.index_map
+  assert index_map is not None
+  return not any(_uses_arguments(index_map, len(grid)))
 
 
 def _inc_grid_by_1(
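The `_uses_arguments` helper added above traces the index map to a jaxpr and uses dead-code elimination (`pe.dce_jaxpr`) to see which grid indices actually influence the block index. A rough, self-contained sketch of the same idea using only the public `jax.make_jaxpr` API; the name `uses_arguments_sketch` and the simple liveness walk are stand-ins for illustration, not the implementation:

```python
import jax

def uses_arguments_sketch(index_map, num_args):
  # Trace the index map with scalar placeholders, then walk the equations
  # backwards to see which inputs can reach an output. Literals carry a
  # ``.val`` attribute and are skipped; everything else is a variable.
  jaxpr = jax.make_jaxpr(index_map)(*(0,) * num_args).jaxpr
  live = {v for v in jaxpr.outvars if not hasattr(v, "val")}
  for eqn in reversed(jaxpr.eqns):
    if any(v in live for v in eqn.outvars):
      live |= {v for v in eqn.invars if not hasattr(v, "val")}
  return [v in live for v in jaxpr.invars]

print(uses_arguments_sketch(lambda i, j: (i, 0), 2))  # [True, False]
print(uses_arguments_sketch(lambda i, j: (0, 0), 2))  # [False, False] -> index invariant
```

A spec whose index map ignores every grid index (e.g. an accumulator output) is what `_is_index_invariant` flags, and such outputs are handled specially later in this patch.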
@@ -85,6 +105,25 @@ def _inc_grid_by_1(
   return tuple(reversed(next_indices))
 
 
+# ``pl.Slice`` uses a different pytree encoding, depending on whether the
+# start/size are static or dynamic. This leads to pytree structure mismatch
+# in the pipeline body. So, we define a different ``Slice`` class below.
+
+
+@dataclasses.dataclass(frozen=True)
+class _Slice:
+  start: int | jax.Array
+  size: int | jax.Array
+
+  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
+    return lax.bitwise_and(self.start == other.start, self.size == other.size)
+
+
+jax.tree_util.register_dataclass(
+    _Slice, data_fields=["start", "size"], meta_fields=[]
+)
+
+
 def emit_pipeline(
     body,
     *,
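The comment in the hunk above is the key constraint: a `lax.fori_loop` carry must keep an identical pytree structure on every iteration, while `pl.Slice` flattens differently depending on whether its fields are static or traced. A minimal sketch of the workaround, using a hypothetical `DemoSlice` in place of `_Slice`:

```python
import dataclasses
import jax
import jax.numpy as jnp
from jax import lax

@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class DemoSlice:  # hypothetical stand-in for _Slice
  start: int | jax.Array
  size: int | jax.Array

  def __eq__(self, other):  # type: ignore
    # Returns a traced array, so the comparison can feed a copy predicate.
    return lax.bitwise_and(self.start == other.start, self.size == other.size)

# The pytree structure is the same whether the fields are static ints or
# traced arrays, so the value can live in a lax.fori_loop carry.
s_static = DemoSlice(0, 128)
s_dynamic = DemoSlice(jnp.int32(0), jnp.int32(128))
assert jax.tree.structure(s_static) == jax.tree.structure(s_dynamic)

# Under tracing, equality is a jax.Array predicate, not a Python bool.
print(jax.jit(lambda a, b: a == b)(s_static, s_dynamic))  # True
```

Registering both fields as data (rather than metadata) is what keeps the carry structure stable even when a slice starts out as the static `(-1, -1)` sentinel and is later replaced by traced values.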
@@ -102,6 +141,16 @@ def emit_pipeline(
     max_concurrent_steps = num_steps
 
   def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
+    for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
+      if any(
+          spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx]  # type: ignore
+          for idx in range(1, len(grid) + 1)
+      ):
+        raise NotImplementedError(
+            f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
+            f" shape {spec.block_shape}."
+        )
+
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
         map(
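The check added above requires the grid to tile the trailing dimensions of every GMEM ref exactly. A quick arithmetic illustration, with made-up shapes that are not from the patch:

```python
# Hypothetical shapes, purely to illustrate the divisibility rule above.
gmem_shape, block_shape, grid = (256, 128), (64, 128), (4, 1)

ok = not any(
    block_shape[-idx] * grid[-idx] != gmem_shape[-idx]
    for idx in range(1, len(grid) + 1)
)
print(ok)  # True; e.g. grid = (3, 1) would trip the NotImplementedError above
```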
@@ -132,13 +181,18 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     def scoped_pipeline(
         *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref
     ):
-
-      in_brefs: Sequence[BufferedRef] = map(
-          BufferedRef, in_specs, in_gmem_refs, in_smem_refs
-      )
-      out_brefs: Sequence[BufferedRef] = map(
-          BufferedRef, out_specs, out_gmem_refs, out_smem_refs
-      )
+      in_brefs: Sequence[BufferedRef] = [
+          BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+          for spec, gmem_ref, smem_ref in zip(
+              in_specs, in_gmem_refs, in_smem_refs
+          )
+      ]
+      out_brefs: Sequence[BufferedRef] = [
+          BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+          for spec, gmem_ref, smem_ref in zip(
+              out_specs, out_gmem_refs, out_smem_refs
+          )
+      ]
 
       for step, indices in enumerate(
           it.islice(it.product(*map(range, grid)), max_concurrent_steps)
@@ -147,10 +201,11 @@ def scoped_pipeline(
 
       def loop_body(step, carry):
         slot = step % max_concurrent_steps
-        indices, fetch_indices = carry
+        indices, fetch_indices, last_store_slices = carry
 
-        # Wait for the current GMEM->SMEM copy to complete.
-        gpu_primitives.barrier_wait(barrier_ref.at[slot])
+        if in_specs:
+          # Wait for the current GMEM->SMEM copy to complete.
+          gpu_primitives.barrier_wait(barrier_ref.at[slot])
         # Wait for the previous output SMEM->GMEM copy to complete.
         gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
 
@@ -159,9 +214,34 @@ def loop_body(step, carry):
             *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
         )
 
+        if not all(bref.is_index_invariant for bref in out_brefs):
+          gpu_primitives.commit_smem()
+
         # Copy the output from SMEM to GMEM.
-        gpu_primitives.commit_smem()
-        map(lambda bref: bref.copy_out(slot, indices), out_brefs)
+        new_store_slices = last_store_slices[:]
+        for idx, bref in enumerate(out_brefs):
+          if bref.is_index_invariant:
+            assert last_store_slices[idx] is None
+            continue
+          assert last_store_slices[idx] is not None
+          new_store_slices[idx] = tuple(
+              _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          )
+          are_same_slices = map(
+              lambda old, new: old == new,
+              last_store_slices[idx],
+              new_store_slices[idx],
+          )
+          slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
+          is_last_step = step == num_steps - 1
+          # TODO(apaszke,slebedev): This still diverges significantly from the
+          # TPU semantics in that it will move on to the next SMEM output slice
+          # even if it's not storing the previous one.
+          bref.copy_out(
+              slot,
+              indices,
+              predicate=lax.bitwise_or(slices_changed, is_last_step),
+          )
 
         fetch_step = step + max_concurrent_steps
         fetch_slot = slot  # (x + y) % y == x % y
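Inside the loop, each varying output is now stored only when its GMEM slice actually moved since the previous step, or on the final step; index-invariant outputs are skipped here and written once after the loop. A host-side sketch of the predicate computation, with plain integers standing in for the `_Slice` tuples (names here are illustrative only):

```python
import functools
import jax.numpy as jnp
from jax import lax

def store_predicate(last_slices, new_slices, step, num_steps):
  # Store iff any slice component changed, or this is the last step.
  are_same = [old == new for old, new in zip(last_slices, new_slices)]
  slices_changed = ~functools.reduce(lax.bitwise_and, are_same)
  return lax.bitwise_or(slices_changed, step == num_steps - 1)

print(store_predicate((jnp.int32(0),), (jnp.int32(0),), step=0, num_steps=4))  # False
print(store_predicate((jnp.int32(0),), (jnp.int32(8),), step=1, num_steps=4))  # True
print(store_predicate((jnp.int32(0),), (jnp.int32(0),), step=3, num_steps=4))  # True
```

In the real loop body this value feeds `copy_out(..., predicate=...)`, so redundant SMEM->GMEM copies are predicated out rather than skipped structurally.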
@@ -174,13 +254,34 @@ def loop_body(step, carry):
             lambda: [None] * len(in_brefs),
         )
 
-        return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid)
+        return (
+            _inc_grid_by_1(indices, grid),
+            _inc_grid_by_1(fetch_indices, grid),
+            new_store_slices,
+        )
 
       indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
       fetch_indices = indices
       for _ in range(max_concurrent_steps):
         fetch_indices = _inc_grid_by_1(fetch_indices, grid)
-      lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices))
+      last_store_slices = [
+          None
+          if bref.is_index_invariant
+          else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+          for bref in out_brefs
+      ]
+      last_indices, _, _ = lax.fori_loop(
+          0, num_steps, loop_body, (indices, fetch_indices, last_store_slices)
+      )
+
+      # Outputs invariant to the sequential axis are never written from inside the
+      # loop. This is the only place where we store them.
+      if all(bref.is_index_invariant for bref in out_brefs):
+        gpu_primitives.commit_smem()
+      last_slot = (num_steps - 1) % max_concurrent_steps
+      for bref in out_brefs:
+        if bref.is_index_invariant:
+          bref.copy_out(last_slot, last_indices, predicate=None)
 
     # Finalize the pipeline.
     gpu_primitives.wait_smem_to_gmem(0)