@@ -623,10 +623,6 @@ def to_layout(self, new_layout: FragmentedLayout):
623623 )
624624
625625 def _pointwise (self , op , * other , output_is_signed : bool | None = None ):
626- is_signed = (
627- output_is_signed if output_is_signed is not None else self .is_signed
628- )
629-
630626 other_arrs = []
631627 for o in other :
632628 if not isinstance (o , FragmentedArray ):
@@ -636,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
636632 raise NotImplementedError (o )
637633
638634 o = FragmentedArray .splat (
639- o , shape = self .shape , layout = self .layout , is_signed = is_signed
635+ o , shape = self .shape , layout = self .layout , is_signed = self . is_signed
640636 )
641637
642638 if isinstance (o .layout , WGSplatFragLayout ):
@@ -646,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
646642 o .registers .flat [0 ],
647643 shape = self .shape ,
648644 layout = self .layout ,
649- is_signed = is_signed ,
645+ is_signed = self . is_signed ,
650646 )
651647 else :
652648 if self .layout != o .layout :
@@ -659,8 +655,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
659655
660656 for idx , reg in np .ndenumerate (self .registers ):
661657 new_regs [idx ] = op (reg , * (o .registers [idx ] for o in other_arrs ))
658+ reg_ty = new_regs .flat [0 ].type
659+ if ir .VectorType .isinstance (reg_ty ):
660+ reg_ty = ir .VectorType (reg_ty ).element_type
661+ if output_is_signed is None and ir .IntegerType .isinstance (reg_ty ):
662+ output_is_signed = self .is_signed
662663 return FragmentedArray (
663- _registers = new_regs , _layout = self .layout , _is_signed = is_signed
664+ _registers = new_regs , _layout = self .layout , _is_signed = output_is_signed
664665 )
665666
666667 def __pos__ (self ):
0 commit comments