Skip to content

Commit b6566c8

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic_gpu] Fixed unbounded recursion in FragmentedArray._pointwise
PiperOrigin-RevId: 700265616
1 parent 16a5607 commit b6566c8

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout):
622622
reg, self.shape, new_layout, is_signed=self.is_signed
623623
)
624624

625-
def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False):
625+
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
626626
# If our layout is a splat, then we should either dispatch to a non-splat
627627
# layout, or broadcast ourselves to the output shape first.
628-
if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout):
628+
if isinstance(self.layout, WGSplatFragLayout):
629629
output_shape = self.shape
630630
for i, o in enumerate(other):
631631
if not isinstance(o, FragmentedArray):
@@ -641,9 +641,10 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_
641641
else:
642642
output_shape = np.broadcast_shapes(output_shape, o.shape)
643643
# If we get here then we haven't found any non-splat layout.
644-
return self.broadcast(output_shape)._pointwise(
645-
op, *other, output_is_signed=output_is_signed, force_no_dispatch=True,
646-
)
644+
if self.shape != output_shape:
645+
return self.broadcast(output_shape)._pointwise(
646+
op, *other, output_is_signed=output_is_signed
647+
)
647648

648649
other_arrs = []
649650
for o in other:

0 commit comments

Comments
 (0)