Skip to content

Commit 04a6652

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic] Fix handling of i1 splat constants
PiperOrigin-RevId: 694248723
1 parent 3b2e4a1 commit 04a6652

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3190,8 +3190,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
31903190
}
31913191
const VectorLayout &layout_out = *layouts_out.front();
31923192
DenseElementsAttr value = cast<DenseElementsAttr>(constant_op.getValue());
3193-
const VectorType target_vty =
3194-
getNativeVregType(vty.getElementType(), ctx.target_shape);
3193+
const VectorType target_vty = getNativeVregOrVmaskType(
3194+
vty.getElementType(), layout_out.bitwidth(), ctx.target_shape);
31953195
if (value.isSplat()) {
31963196
if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
31973197
return op.emitOpError(

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ class VectorLayoutInferer {
365365
TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported");
366366
TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr");
367367
auto bitwidth = ty.getElementTypeBitWidth();
368+
if (bitwidth == 1) {
369+
// i1 is a special case where the layout bitwidth can be different from
370+
// the element bitwidth, see comment in VectorLayout class
371+
bitwidth = kNativeBitwidth;
372+
}
368373
if (elems.isSplat()) {
369374
if (ty.getRank() == 1) {
370375
// Here, we choose to lay out along lanes arbitrarily. It would be

0 commit comments

Comments
 (0)