@@ -187,7 +187,7 @@ class VectorLayoutInferer {
187187 false_ty.getElementTypeBitWidth () == kNativeBitwidth ,
188188 " Only 32-bit select supported" );
189189 }
190- if (inferElementwise (&any_op, /* check_bitwidth= */ false ).failed ()) {
190+ if (inferElementwise (&any_op).failed ()) {
191191 return failure ();
192192 }
193193 } else if (auto op = dyn_cast<arith::ExtUIOp>(any_op)) {
@@ -198,7 +198,7 @@ class VectorLayoutInferer {
198198 auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth ()
199199 : op.getIn ().getType ().getIntOrFloatBitWidth ();
200200 if (in_bitwidth == 1 ) {
201- if (inferElementwise (&any_op, /* check_bitwidth= */ false ).failed ()) {
201+ if (inferElementwise (&any_op).failed ()) {
202202 return failure ();
203203 }
204204 } else {
@@ -214,7 +214,7 @@ class VectorLayoutInferer {
214214 TPU_CHECK_OP (static_cast <bool >(lhs_ty) == static_cast <bool >(rhs_ty),
215215 " Only one side of cmp is a vector?" );
216216 // TODO(tlongeri): Check that TPU generation supports comparison.
217- if (inferElementwise (&any_op, /* check_bitwidth= */ false ).failed ()) {
217+ if (inferElementwise (&any_op).failed ()) {
218218 return failure ();
219219 }
220220 } else if (auto op = dyn_cast<arith::ConstantOp>(any_op)) {
@@ -1726,7 +1726,7 @@ class VectorLayoutInferer {
17261726 return success ();
17271727 }
17281728
1729- LogicalResult inferElementwise (Operation *op, bool check_bitwidth = true ) {
1729+ LogicalResult inferElementwise (Operation *op) {
17301730 TPU_CHECK_OP (op->getNumResults () == 1 , " only one result supported" );
17311731 TPU_CHECK_OP (op->getNumOperands () > 0 ,
17321732 " elementwise ops with no operands unsupported" );
@@ -1735,26 +1735,45 @@ class VectorLayoutInferer {
17351735 std::optional<VectorLayout> out_layout_candidate;
17361736 std::optional<VectorLayout> out_layout;
17371737 SmallVector<std::optional<Layout>, 4 > in_layouts;
1738- int64_t bit_width = -1 ;
1738+ int64_t bitwidth = -1 ;
1739+ // Find the bitwidth of the operands/results. They must all be the same
1740+ // except for the case of i1s, which use a "fake" bitwidth for layouts.
1741+ // They can be relayouted (in principle) to any other fake bitwidth, so we
1742+ // don't commit to their bitwidth. See comments in VectorLayout class.
1743+ for (Value val : llvm::concat<Value>(op->getOperands (), op->getResults ())) {
1744+ if (const VectorType vty = dyn_cast<VectorType>(val.getType ())) {
1745+ const int64_t val_bitwidth = vty.getElementTypeBitWidth ();
1746+ if (val_bitwidth != 1 ) {
1747+ if (bitwidth == -1 ) {
1748+ bitwidth = val_bitwidth;
1749+ } else if (bitwidth != val_bitwidth) {
1750+ return op->emitOpError (
1751+ " Mismatched bitwidth in elementwise for non-i1 "
1752+ " operands/results" );
1753+ }
1754+ }
1755+ }
1756+ }
17391757 for (int64_t i = 0 ; i < op->getNumOperands (); ++i) {
17401758 if (auto vty = dyn_cast<VectorType>(op->getOperand (i).getType ())) {
1741- if (bit_width == -1 ) {
1742- bit_width = vty.getElementTypeBitWidth ();
1743- }
1744- TPU_CHECK_OP (
1745- !check_bitwidth || bit_width == vty.getElementTypeBitWidth (),
1746- " Generic elementwise rule only supports operands of same width" );
17471759 auto some_layout = getLayout (op->getOperand (i));
17481760 TPU_CHECK_OP (some_layout.has_value (), " missing vector layout" );
17491761 auto &layout = *some_layout;
1750- // If the input is fully replicated, don't use it to commit to any
1751- // layout. Replicated values are easy to relayout.
1752- if (is_fully_replicated (some_layout)) {
1762+ if (bitwidth == -1 ) {
1763+ // All operands/results are i1s, just commit to the first bitwidth
1764+ DCHECK (!out_layout.has_value ());
1765+ bitwidth = layout.bitwidth ();
1766+ out_layout = layout;
1767+ in_layouts.push_back (layout);
1768+ } else if (bitwidth != layout.bitwidth ()) {
1769+ DCHECK_EQ (vty.getElementTypeBitWidth (), 1 );
1770+ in_layouts.push_back (std::nullopt );
1771+ } else if (is_fully_replicated (some_layout)) {
1772+ // If the input is fully replicated, don't use it to commit to any
1773+ // layout. Replicated values are easy to relayout.
17531774 in_layouts.push_back (std::nullopt );
17541775 out_layout_candidate = layout;
1755- continue ;
1756- }
1757- if (!out_layout) {
1776+ } else if (!out_layout) {
17581777 // TODO(apaszke): There are probably smarter ways to choose layout.
17591778 out_layout = layout;
17601779 in_layouts.push_back (some_layout);
@@ -1768,8 +1787,9 @@ class VectorLayoutInferer {
17681787 // any replication bits that might have been present in out_layout,
17691788 // since there is no guarantee that the conflicting inputs could
17701789 // even become replicated.
1790+ DCHECK_EQ (out_layout->bitwidth (), bitwidth);
17711791 out_layout =
1772- VectorLayout (out_layout-> bitwidth () ,
1792+ VectorLayout (bitwidth,
17731793 {out_layout->offsets ()[0 ].value_or (0 ),
17741794 out_layout->offsets ()[1 ].value_or (0 )},
17751795 out_layout->tiling (), out_layout->implicit_dim ());
@@ -1784,9 +1804,6 @@ class VectorLayoutInferer {
17841804 }
17851805 Layout final_out_layout = std::nullopt ;
17861806 if (auto out_vty = dyn_cast<VectorType>(op->getResult (0 ).getType ())) {
1787- TPU_CHECK_OP (
1788- !check_bitwidth || bit_width == out_vty.getElementTypeBitWidth (),
1789- " Generic elementwise rule can't change element type width" );
17901807 if (out_layout) {
17911808 final_out_layout = *out_layout;
17921809 } else if (out_layout_candidate) {
0 commit comments