Skip to content

Commit 651ab18

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Fix elementwise inference with i1s
PiperOrigin-RevId: 703263310
1 parent d782b24 commit 651ab18

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class VectorLayoutInferer {
187187
false_ty.getElementTypeBitWidth() == kNativeBitwidth,
188188
"Only 32-bit select supported");
189189
}
190-
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
190+
if (inferElementwise(&any_op).failed()) {
191191
return failure();
192192
}
193193
} else if (auto op = dyn_cast<arith::ExtUIOp>(any_op)) {
@@ -198,7 +198,7 @@ class VectorLayoutInferer {
198198
auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth()
199199
: op.getIn().getType().getIntOrFloatBitWidth();
200200
if (in_bitwidth == 1) {
201-
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
201+
if (inferElementwise(&any_op).failed()) {
202202
return failure();
203203
}
204204
} else {
@@ -214,7 +214,7 @@ class VectorLayoutInferer {
214214
TPU_CHECK_OP(static_cast<bool>(lhs_ty) == static_cast<bool>(rhs_ty),
215215
"Only one side of cmp is a vector?");
216216
// TODO(tlongeri): Check that TPU generation supports comparison.
217-
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
217+
if (inferElementwise(&any_op).failed()) {
218218
return failure();
219219
}
220220
} else if (auto op = dyn_cast<arith::ConstantOp>(any_op)) {
@@ -1726,7 +1726,7 @@ class VectorLayoutInferer {
17261726
return success();
17271727
}
17281728

1729-
LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {
1729+
LogicalResult inferElementwise(Operation *op) {
17301730
TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported");
17311731
TPU_CHECK_OP(op->getNumOperands() > 0,
17321732
"elementwise ops with no operands unsupported");
@@ -1735,26 +1735,45 @@ class VectorLayoutInferer {
17351735
std::optional<VectorLayout> out_layout_candidate;
17361736
std::optional<VectorLayout> out_layout;
17371737
SmallVector<std::optional<Layout>, 4> in_layouts;
1738-
int64_t bit_width = -1;
1738+
int64_t bitwidth = -1;
1739+
// Find the bitwidth of the operands/results. They must all be the same
1740+
// except for the case of i1s, which use a "fake" bitwidth for layouts.
1741+
// They can be relayouted (in principle) to any other fake bitwidth, so we
1742+
// don't commit to their bitwidth. See comments in VectorLayout class.
1743+
for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
1744+
if (const VectorType vty = dyn_cast<VectorType>(val.getType())) {
1745+
const int64_t val_bitwidth = vty.getElementTypeBitWidth();
1746+
if (val_bitwidth != 1) {
1747+
if (bitwidth == -1) {
1748+
bitwidth = val_bitwidth;
1749+
} else if (bitwidth != val_bitwidth) {
1750+
return op->emitOpError(
1751+
"Mismatched bitwidth in elementwise for non-i1 "
1752+
"operands/results");
1753+
}
1754+
}
1755+
}
1756+
}
17391757
for (int64_t i = 0; i < op->getNumOperands(); ++i) {
17401758
if (auto vty = dyn_cast<VectorType>(op->getOperand(i).getType())) {
1741-
if (bit_width == -1) {
1742-
bit_width = vty.getElementTypeBitWidth();
1743-
}
1744-
TPU_CHECK_OP(
1745-
!check_bitwidth || bit_width == vty.getElementTypeBitWidth(),
1746-
"Generic elementwise rule only supports operands of same width");
17471759
auto some_layout = getLayout(op->getOperand(i));
17481760
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
17491761
auto &layout = *some_layout;
1750-
// If the input is fully replicated, don't use it to commit to any
1751-
// layout. Replicated values are easy to relayout.
1752-
if (is_fully_replicated(some_layout)) {
1762+
if (bitwidth == -1) {
1763+
// All operands/results are i1s, just commit to the first bitwidth
1764+
DCHECK(!out_layout.has_value());
1765+
bitwidth = layout.bitwidth();
1766+
out_layout = layout;
1767+
in_layouts.push_back(layout);
1768+
} else if (bitwidth != layout.bitwidth()) {
1769+
DCHECK_EQ(vty.getElementTypeBitWidth(), 1);
1770+
in_layouts.push_back(std::nullopt);
1771+
} else if (is_fully_replicated(some_layout)) {
1772+
// If the input is fully replicated, don't use it to commit to any
1773+
// layout. Replicated values are easy to relayout.
17531774
in_layouts.push_back(std::nullopt);
17541775
out_layout_candidate = layout;
1755-
continue;
1756-
}
1757-
if (!out_layout) {
1776+
} else if (!out_layout) {
17581777
// TODO(apaszke): There are probably smarter ways to choose layout.
17591778
out_layout = layout;
17601779
in_layouts.push_back(some_layout);
@@ -1768,8 +1787,9 @@ class VectorLayoutInferer {
17681787
// any replication bits that might have been present in out_layout,
17691788
// since there is no guarantee that the conflicting inputs could
17701789
// even become replicated.
1790+
DCHECK_EQ(out_layout->bitwidth(), bitwidth);
17711791
out_layout =
1772-
VectorLayout(out_layout->bitwidth(),
1792+
VectorLayout(bitwidth,
17731793
{out_layout->offsets()[0].value_or(0),
17741794
out_layout->offsets()[1].value_or(0)},
17751795
out_layout->tiling(), out_layout->implicit_dim());
@@ -1784,9 +1804,6 @@ class VectorLayoutInferer {
17841804
}
17851805
Layout final_out_layout = std::nullopt;
17861806
if (auto out_vty = dyn_cast<VectorType>(op->getResult(0).getType())) {
1787-
TPU_CHECK_OP(
1788-
!check_bitwidth || bit_width == out_vty.getElementTypeBitWidth(),
1789-
"Generic elementwise rule can't change element type width");
17901807
if (out_layout) {
17911808
final_out_layout = *out_layout;
17921809
} else if (out_layout_candidate) {

0 commit comments

Comments
 (0)