Skip to content

Commit 914600a

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Simplify logic for pointwise splat operands
The previous version of the code was too complicated and failed to account for the fact that in an op that broadcasts there does not necessarily exist and operand that has the output shape. Reading through the code now, it's a bit weird that we allow implicit broadcasting of operands with splat layouts, but not any other operands. But I guess that's a thing to implement later. PiperOrigin-RevId: 699983045
1 parent 69e3f0d commit 914600a

File tree

1 file changed

+16
-26
lines changed

1 file changed

+16
-26
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -623,37 +623,27 @@ def to_layout(self, new_layout: FragmentedLayout):
623623
)
624624

625625
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
626+
# If our layout is a splat, then we should either dispatch to a non-splat
627+
# layout, or broadcast ourselves to the output shape first.
626628
if isinstance(self.layout, WGSplatFragLayout):
627-
# Find either the largest operand or an operand that has a
628-
# concrete layout base the layout computation of that.
629-
widest_idx = None
629+
output_shape = self.shape
630630
for i, o in enumerate(other):
631631
if not isinstance(o, FragmentedArray):
632632
continue
633633
elif not isinstance(o.layout, WGSplatFragLayout):
634-
widest_idx = i
635-
break
636-
elif not o.layout.can_broadcast_to(self.layout.shape):
637-
# Note: equal shapes can be broadcast to each other. Using
638-
# the negation we make sure to only consider strictly larger
639-
# shapes so that we don't end up ping ponging between equal
640-
# shapes.
641-
widest_idx = i
642-
643-
if widest_idx is not None:
644-
# We need to retain the order of arguments that the op
645-
# expects.
646-
def _op(wide_o, self_o, *args):
647-
pre_wide = args[:widest_idx - 1]
648-
post_wide = args[widest_idx - 1:]
649-
return op(self_o, *pre_wide, wide_o, *post_wide)
650-
return other[widest_idx]._pointwise(
651-
_op,
652-
self,
653-
*other[:widest_idx],
654-
*other[widest_idx + 1:],
655-
output_is_signed=output_is_signed,
656-
)
634+
return o._pointwise(
635+
lambda o, *args: op(*args[:i], o, *args[i:]),
636+
self,
637+
*other[:i],
638+
*other[i + 1 :],
639+
output_is_signed=output_is_signed,
640+
)
641+
else:
642+
output_shape = np.broadcast_shapes(output_shape, o.shape)
643+
# If we get here then we haven't found any non-splat layout.
644+
return self.broadcast(output_shape)._pointwise(
645+
op, *other, output_is_signed=output_is_signed
646+
)
657647

658648
other_arrs = []
659649
for o in other:

0 commit comments

Comments
 (0)