Skip to content

Commit 1afb05e

Browse files
petebuGoogle-ML-Automation
authored andcommitted
[mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`. PiperOrigin-RevId: 698324596
1 parent ae46b75 commit 1afb05e

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,6 @@ def to_layout(self, new_layout: FragmentedLayout):
623623
)
624624

625625
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
626-
is_signed = (
627-
output_is_signed if output_is_signed is not None else self.is_signed
628-
)
629-
630626
other_arrs = []
631627
for o in other:
632628
if not isinstance(o, FragmentedArray):
@@ -636,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
636632
raise NotImplementedError(o)
637633

638634
o = FragmentedArray.splat(
639-
o, shape=self.shape, layout=self.layout, is_signed=is_signed
635+
o, shape=self.shape, layout=self.layout, is_signed=self.is_signed
640636
)
641637

642638
if isinstance(o.layout, WGSplatFragLayout):
@@ -646,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
646642
o.registers.flat[0],
647643
shape=self.shape,
648644
layout=self.layout,
649-
is_signed=is_signed,
645+
is_signed=self.is_signed,
650646
)
651647
else:
652648
if self.layout != o.layout:
@@ -659,8 +655,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
659655

660656
for idx, reg in np.ndenumerate(self.registers):
661657
new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs))
658+
reg_ty = new_regs.flat[0].type
659+
if ir.VectorType.isinstance(reg_ty):
660+
reg_ty = ir.VectorType(reg_ty).element_type
661+
if output_is_signed is None and ir.IntegerType.isinstance(reg_ty):
662+
output_is_signed = self.is_signed
662663
return FragmentedArray(
663-
_registers=new_regs, _layout=self.layout, _is_signed=is_signed
664+
_registers=new_regs, _layout=self.layout, _is_signed=output_is_signed
664665
)
665666

666667
def __pos__(self):

0 commit comments

Comments
 (0)