
Commit 9e5edb7

bythew3i authored and Google-ML-Automation committed
[Mosaic TPU] Support packed type matmul with arbitrary shapes.
This CL removes all the shape constraints on matmul for all types. We only need to mask out subelements on the contracting dim. Instead of unpacking the data and applying masks, we create a VREG-sized i32 "mask" that carries the subelement mask info and logical-and it with the target vreg. This way, masking subelements costs each target vreg only 1 op (logical_and) instead of 3 ops (unpacking + select + packing).

PiperOrigin-RevId: 702480077
1 parent d990dcf · commit 9e5edb7
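To make the masking trick concrete, here is a small standalone C++ sketch (not part of the commit; the (8, 128) vreg shape, the parameter values, and all names are illustrative) that computes the per-sublane i32 mask words described above for a packed operand:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative parameters: a vreg with 8 x32 sublanes holding bf16 data
  // (packing = 2, i.e. two 16-bit subelements per 32-bit lane), with 5 rows
  // of padding to mask out at the bottom -- the example used in the new code.
  const int kSublanes = 8;
  const int packing = 2;
  const int bitwidth = 32 / packing;
  const int padding_bottom = 5;

  const int x32_padding_bottom = padding_bottom / packing;  // fully-zeroed sublanes
  const int sub_padding = padding_bottom % packing;         // leftover subelements

  for (int s = 0; s < kSublanes; ++s) {
    uint32_t word;
    if (s >= kSublanes - x32_padding_bottom) {
      word = 0x00000000u;  // sublane lies entirely in the padding
    } else if (sub_padding > 0 && s == kSublanes - x32_padding_bottom - 1) {
      // Blended sublane: keep only the low (packing - sub_padding) subelements.
      word = 0xffffffffu >> (sub_padding * bitwidth);
    } else {
      word = 0xffffffffu;  // sublane holds only valid data
    }
    printf("sublane %d: 0x%08x\n", s, word);
  }
  return 0;
}

Each affected vreg is then bitcast to its i32 view, ANDed with this mask, and bitcast back, so sub-element masking costs one logical_and instead of an unpack/select/pack sequence.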

3 files changed: +73 -89 lines changed


jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 9 additions & 0 deletions
@@ -544,6 +544,15 @@ LogicalResult MatmulOp::verify() {
   // however, a good start and the recommended place to add more invariants.
   const VectorType lhs_ty = getLhs().getType();
   const VectorType rhs_ty = getRhs().getType();
+  const VectorType acc_ty = getAcc().getType();
+  const VectorType res_ty = getResult().getType();
+  if (acc_ty != res_ty) {
+    return emitOpError(
+        "Not implemented: matmul acc and result have different types");
+  }
+  if (acc_ty.getElementTypeBitWidth() != 32) {
+    return emitOpError("Expected matmul acc to be 32-bit");
+  }
 
   if (getTransposeLhs()) {
     emitOpError(

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 49 additions & 29 deletions
@@ -1764,19 +1764,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   // TODO(tlongeri): This should be part of the tpu::MatmulOp verifier
   TPU_ASSERT_EQ_OP(lhs_shape.size(), 2);
   TPU_ASSERT_EQ_OP(rhs_shape.size(), 2);
-  // The code below puts no constraints on the second dimension of both lhs and
-  // rhs. However, leading axis of lhs and rhs needs to be a multiple of native
-  // tiling for packed types.
-  if (layout_lhs.packing() != 1 && lhs_shape[0] % layout_lhs.tiling()[0] != 0) {
-    return op.emitOpError(
-        "Not implemented: Unsupported LHS shape with padded tiling and "
-        "narrower data type");
-  }
-  if (layout_rhs.packing() != 1 && rhs_shape[0] % layout_rhs.tiling()[0] != 0) {
-    return op.emitOpError(
-        "Not implemented: Unsupported RHS shape with padded tiling and "
-        "narrower data type");
-  }
 
   const int64_t padded_lhs_rows =
       llvm::alignTo(lhs_shape[0], layout_lhs.tiling()[0]);
@@ -1787,10 +1774,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   const int64_t padded_rhs_cols =
       llvm::alignTo(rhs_shape[1], layout_rhs.tiling()[1]);
 
-  if (llvm::alignTo(lhs_shape[0], layout_acc.tiling()[0]) != padded_lhs_rows) {
-    return op.emitOpError(
-        "Not implemented: Matmul acc requires less padding than lhs");
-  }
   FAILUREOR_ASSIGN_OR_RETURN(
       xla::Array<Value> lhs_vregs,
       disassemble(builder, layout_lhs, lhs, ctx.target_shape));
@@ -1801,7 +1784,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
       xla::Array<Value> rhs_vregs,
       disassemble(builder, layout_rhs, rhs, ctx.target_shape));
   TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
-  TPU_ASSERT_EQ_OP(padded_lhs_rows, acc_vregs.dim(0) * layout_acc.tiling()[0]);
   TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
 
   const VectorType i32_vreg_ty =
@@ -1823,27 +1805,64 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
 
   // We can also extend this helper function with padding_top and padding_left
   // based on the offsets in vregs.
-  // TODO(b/341729764): Support mask subelements.
+  const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
+      op.getLoc(),
+      DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
+  const Value i32_max_vreg = builder.create<arith::ConstantOp>(
+      op.getLoc(), DenseElementsAttr::get(
+                       i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
   auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
                        int64_t padding_right) {
-    const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
-        op.getLoc(),
-        DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
    auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
    int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
    // Mask out the bottom.
    if (padding_bottom > 0) {
      // We have aligned the row sizes of LHS and RHS to a multiple of the
      // native tiling at the beginning of this rule. Therefore, it is safe to
      // bitcast to an x32 vreg for masking.
-      CHECK_EQ(padding_bottom % packing, 0);
-      padding_bottom /= packing;
-      auto mask_bottom = getX32VmaskByPaddingEnd(0, padding_bottom);
+      int sub_padding = padding_bottom % packing;
+      int x32_padding_bottom = padding_bottom / packing;
+      auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
+      // Create an int32 vreg which contains the subelement masking and then
+      // logical_and it with the target vreg to mask out the unaligned padding.
+      // E.g. if padding_bottom = 5, packing = 2, and the vreg shape is
+      // [8, 128], then the mask will be:
+      //
+      // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+      // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+      // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+      // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+      // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
+      // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
+      // sublane 6: [0x00000000, 0x00000000, ..., 0x00000000]
+      // sublane 7: [0x00000000, 0x00000000, ..., 0x00000000]
+      //
+      // This way, in order to mask sub-elements, each target vreg only
+      // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
+      // + packing).
+      Value partial_sublane_mask = builder.create<arith::ConstantOp>(
+          op.getLoc(),
+          DenseElementsAttr::get(
+              i32_vreg_ty,
+              builder.getI32IntegerAttr(
+                  0xffffffff >>
+                  (sub_padding * vreg_ty.getElementTypeBitWidth()))));
+      // Insert 0xffffffff above the blended sublane.
+      Value sublane_mask = builder.create<arith::SelectOp>(
+          getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
+          partial_sublane_mask);
+      // Insert 0 below the blended sublane.
+      sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
+                                                     i32_zeros_vreg);
      for (int64_t i = 0; i < vregs.dim(1); ++i) {
        Value &vreg = vregs({vregs.dim(0) - 1, i});
        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
-                                                   i32_zeros_vreg);
+        if (sub_padding > 0) {
+          i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
+        } else {
+          i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
+                                                     i32_zeros_vreg);
+        }
        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
      }
    }
@@ -1929,8 +1948,9 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
                                  lhs_zeros_vreg);
   xla::Array<Value> target_rhs_vregs(
       {target_rhs_row_vregs, target_rhs_col_vregs}, rhs_zeros_vreg);
-  xla::Array<Value> target_acc_vregs({acc_vregs.dim(0), target_acc_col_vregs},
-                                     acc_zeros_vreg);
+  xla::Array<Value> target_acc_vregs(
+      {lhs_vregs.dim(0) * layout_lhs.packing(), target_acc_col_vregs},
+      acc_zeros_vreg);
   target_lhs_vregs.UpdateSlice(lhs_vregs, {0, 0});
   target_rhs_vregs.UpdateSlice(rhs_vregs, {0, 0});
   target_acc_vregs.UpdateSlice(acc_vregs, {0, 0});
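The new target_acc_vregs row count follows from vreg geometry: a packed lhs vreg covers packing times as many rows as a 32-bit acc vreg, so the acc needs packing times as many row-vregs. A quick sanity check with assumed numbers (bf16 lhs on an (8, 128) target shape; not taken from the commit):

#include <cstdint>

// Assumed example: bf16 lhs (packing = 2) with native tiling (16, 128);
// the 32-bit acc uses native tiling (8, 128).
constexpr int64_t kPacking = 2;
constexpr int64_t kPaddedLhsRows = 48;                   // aligned to 16
constexpr int64_t kLhsRowVregs = kPaddedLhsRows / 16;    // 3 lhs row-vregs
constexpr int64_t kAccRowVregs = kLhsRowVregs * kPacking;
static_assert(kAccRowVregs == kPaddedLhsRows / 8,
              "acc needs packing times as many row-vregs as the packed lhs");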

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 15 additions & 60 deletions
@@ -903,66 +903,21 @@ class VectorLayoutInferer {
   }
 
   LogicalResult infer(tpu::MatmulOp op) {
-    auto get_operand_layout =
-        [&](Value v, llvm::StringRef operand_name,
-            std::optional<int64_t> major_multiple = std::nullopt,
-            std::optional<int64_t> minor_multiple =
-                std::nullopt) -> std::optional<VectorLayout> {
-      auto layout = getLayout(v);
-      if (!layout.has_value()) {
-        op->emitOpError("Internal error: assert failed: Operand ")
-            << operand_name << " has no vector layout";
-        return std::nullopt;
-      }
-      auto vty = cast<VectorType>(v.getType());
-      auto tiling = nativeTiling(vty.getElementTypeBitWidth());
-      auto shape = vty.getShape().take_back(2);
-      if (shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
-          shape[1] % minor_multiple.value_or(tiling[1]) != 0) {
-        op->emitOpError("Matmul operand ")
-            << operand_name << " must have a shape divisible by ("
-            << major_multiple.value_or(tiling[0]) << ", "
-            << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0]
-            << ", " << shape[1] << ")";
-        return std::nullopt;
-      }
-      // Override tiling to match the native one.
-      return VectorLayout(layout->bitwidth(), {0, 0}, tiling,
-                          ImplicitDim::kNone);
-    };
-    auto res_ty = dyn_cast<VectorType>(op->getResult(0).getType());
-    TPU_CHECK_OP(res_ty, "only vector results supported");
-    TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth,
-                 "only 32-bit matmul results supported");
-    std::array<Layout, 3> in_layout;
-    CHECK_EQ(op->getNumOperands(), 3);
-    std::optional<int64_t> lhs_major_multiple;
-    std::optional<int64_t> rhs_major_multiple;
-    // We don't restrict the first lhs axis when the data is not packed.
-    if (cast<VectorType>(op->getOperand(0).getType())
-            .getElementTypeBitWidth() == kNativeBitwidth) {
-      lhs_major_multiple = 1;
-    }
-    // We don't restrict the first rhs axis when the data is not packed.
-    if (cast<VectorType>(op->getOperand(1).getType())
-            .getElementTypeBitWidth() == kNativeBitwidth) {
-      rhs_major_multiple = 1;
-    }
-    in_layout[0] =
-        get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1);
-    if (!in_layout[0].has_value()) {
-      return failure();
-    }
-    in_layout[1] =
-        get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1);
-    if (!in_layout[1].has_value()) {
-      return failure();
-    }
-    in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1);
-    if (!in_layout[2].has_value()) {
-      return failure();
-    }
-    setLayout(op, in_layout,
+    auto lhs_bitwidth = op.getLhs().getType().getElementTypeBitWidth();
+    auto rhs_bitwidth = op.getRhs().getType().getElementTypeBitWidth();
+    auto acc_bitwidth = op.getAcc().getType().getElementTypeBitWidth();
+    auto res_bitwidth = op.getResult().getType().getElementTypeBitWidth();
+    TPU_CHECK_OP(acc_bitwidth == kNativeBitwidth,
+                 "Expected 32-bit acc in tpu::MatmulOp");
+    TPU_CHECK_OP(res_bitwidth == kNativeBitwidth,
+                 "Expected 32-bit result in tpu::MatmulOp");
+    auto lhs_layout = VectorLayout(
+        lhs_bitwidth, {0, 0}, nativeTiling(lhs_bitwidth), ImplicitDim::kNone);
+    auto rhs_layout = VectorLayout(
+        rhs_bitwidth, {0, 0}, nativeTiling(rhs_bitwidth), ImplicitDim::kNone);
+    auto acc_layout = VectorLayout(
+        acc_bitwidth, {0, 0}, nativeTiling(acc_bitwidth), ImplicitDim::kNone);
+    setLayout(op, {lhs_layout, rhs_layout, acc_layout},
               VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
                            ImplicitDim::kNone));
     return success();
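For reference, the inferred layouts now depend only on each operand's element bitwidth. A minimal sketch of the assumed nativeTiling behavior on an (8, 128) target shape (the real helper lives in the inferer; this is illustrative only):

#include <array>
#include <cstdint>

// Assumed semantics: narrower element types get proportionally taller
// native tiles, so one tile always fills a full (8, 128) x32 vreg.
std::array<int64_t, 2> native_tiling(int bitwidth) {
  const int packing = 32 / bitwidth;  // 1 for f32/i32, 2 for bf16, 4 for i8
  return {8 * packing, 128};
}

With the divisibility checks gone, inference unconditionally assigns each operand its native-tiled layout with zero offsets; any residual padding is handled in apply_vector_layout by the masking shown above.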
