Commit 14ddb81

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Avoid double-predication when async_copy predicate is specified
PiperOrigin-RevId: 700999181
1 parent b09b077 commit 14ddb81

1 file changed: +2 -2 lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 2 additions & 2 deletions
@@ -356,7 +356,7 @@ def async_copy(
       arrive: bool | None = None,
       uniform: bool = True,
       collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
-      predicate: ir.Value | None = None,
+      predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG.
   ):
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
@@ -504,7 +504,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
 
     uniform_ctx = (
         functools.partial(utils.single_thread, per_block=False)
-        if uniform
+        if uniform and predicate is None
         else contextlib.nullcontext
     )
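
The context-manager selection in the second hunk can be modelled in isolation. The following is a minimal sketch in plain Python, not the Mosaic GPU implementation: `single_thread` here is a hypothetical stand-in for `utils.single_thread` that only records that it was entered, and `pick_uniform_ctx` and `issue_copy` are illustrative names that do not appear in `core.py`. It assumes, as the new comment on `predicate` states, that a caller-supplied predicate already selects 0 or 1 threads.

```python
import contextlib
import functools


# Hypothetical stand-in for utils.single_thread: it only logs that the
# body was wrapped in an extra single-thread predicate.
@contextlib.contextmanager
def single_thread(log, per_block=False):
  log.append("single_thread")
  yield


def pick_uniform_ctx(uniform, predicate, log):
  """Mirrors the post-change condition: skip the wrapper if a predicate is given."""
  return (
      functools.partial(single_thread, log, per_block=False)
      if uniform and predicate is None
      else contextlib.nullcontext
  )


def issue_copy(uniform, predicate):
  """Returns the sequence of events recorded while 'issuing' a copy."""
  log = []
  with pick_uniform_ctx(uniform, predicate, log)():
    log.append("copy")
  return log
```

In this model, `issue_copy(True, None)` enters the single-thread wrapper before the copy, while `issue_copy(True, some_predicate)` falls through to `contextlib.nullcontext`, so the copy is guarded by the caller's predicate alone.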
