Skip to content

Commit 4a41aa0

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Removed unnecessarily strict check in emit_pipeline
PiperOrigin-RevId: 703117465
1 parent 5fe5206 commit 4a41aa0

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,6 @@ def emit_pipeline(
181181
delay_release = 0 # No need to delay anything.
182182

183183
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
184-
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
185-
if any(
186-
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
187-
for idx in range(1, len(grid) + 1)
188-
if spec.block_shape is not None
189-
):
190-
raise NotImplementedError(
191-
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
192-
f" shape {spec.block_shape}."
193-
)
194-
195184
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
196185
in_smem_refs, out_smem_refs = util.split_list(
197186
[

0 commit comments

Comments
 (0)