Skip to content

Commit 34fe66b

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
[mgpu] foreach should not try to create an array if it didn't create the registers due to create_array=False.
PiperOrigin-RevId: 700955830
1 parent bdee4c3 commit 34fe66b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,8 @@ def foreach(
12651265
if create_array:
12661266
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)
12671267

1268-
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
1268+
if create_array:
1269+
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
12691270

12701271
def store_untiled(self, ref: ir.Value):
12711272
if not ir.MemRefType.isinstance(ref.type):

0 commit comments

Comments
 (0)