@@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout):
622622 reg , self .shape , new_layout , is_signed = self .is_signed
623623 )
624624
625- def _pointwise (self , op , * other , output_is_signed : bool | None = None , force_no_dispatch = False ):
625+ def _pointwise (self , op , * other , output_is_signed : bool | None = None ):
626626 # If our layout is a splat, then we should either dispatch to a non-splat
627627 # layout, or broadcast ourselves to the output shape first.
628- if not force_no_dispatch and isinstance (self .layout , WGSplatFragLayout ):
628+ if isinstance (self .layout , WGSplatFragLayout ):
629629 output_shape = self .shape
630630 for i , o in enumerate (other ):
631631 if not isinstance (o , FragmentedArray ):
@@ -641,9 +641,10 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_
641641 else :
642642 output_shape = np .broadcast_shapes (output_shape , o .shape )
643643 # If we get here then we haven't found any non-splat layout.
644- return self .broadcast (output_shape )._pointwise (
645- op , * other , output_is_signed = output_is_signed , force_no_dispatch = True ,
646- )
644+ if self .shape != output_shape :
645+ return self .broadcast (output_shape )._pointwise (
646+ op , * other , output_is_signed = output_is_signed
647+ )
647648
648649 other_arrs = []
649650 for o in other :
0 commit comments