@@ -1764,19 +1764,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
17641764 // TODO(tlongeri): This should be part of the tpu::MatmulOp verifier
17651765 TPU_ASSERT_EQ_OP (lhs_shape.size (), 2 );
17661766 TPU_ASSERT_EQ_OP (rhs_shape.size (), 2 );
1767- // The code below puts no constraints on the second dimension of both lhs and
1768- // rhs. However, leading axis of lhs and rhs needs to be a multiple of native
1769- // tiling for packed types.
1770- if (layout_lhs.packing () != 1 && lhs_shape[0 ] % layout_lhs.tiling ()[0 ] != 0 ) {
1771- return op.emitOpError (
1772- " Not implemented: Unsupported LHS shape with padded tiling and "
1773- " narrower data type" );
1774- }
1775- if (layout_rhs.packing () != 1 && rhs_shape[0 ] % layout_rhs.tiling ()[0 ] != 0 ) {
1776- return op.emitOpError (
1777- " Not implemented: Unsupported RHS shape with padded tiling and "
1778- " narrower data type" );
1779- }
17801767
17811768 const int64_t padded_lhs_rows =
17821769 llvm::alignTo (lhs_shape[0 ], layout_lhs.tiling ()[0 ]);
@@ -1787,10 +1774,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
17871774 const int64_t padded_rhs_cols =
17881775 llvm::alignTo (rhs_shape[1 ], layout_rhs.tiling ()[1 ]);
17891776
1790- if (llvm::alignTo (lhs_shape[0 ], layout_acc.tiling ()[0 ]) != padded_lhs_rows) {
1791- return op.emitOpError (
1792- " Not implemented: Matmul acc requires less padding than lhs" );
1793- }
17941777 FAILUREOR_ASSIGN_OR_RETURN (
17951778 xla::Array<Value> lhs_vregs,
17961779 disassemble (builder, layout_lhs, lhs, ctx.target_shape ));
@@ -1801,7 +1784,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
18011784 xla::Array<Value> rhs_vregs,
18021785 disassemble (builder, layout_rhs, rhs, ctx.target_shape ));
18031786 TPU_ASSERT_EQ_OP (padded_lhs_rows, lhs_vregs.dim (0 ) * layout_lhs.tiling ()[0 ]);
1804- TPU_ASSERT_EQ_OP (padded_lhs_rows, acc_vregs.dim (0 ) * layout_acc.tiling ()[0 ]);
18051787 TPU_ASSERT_EQ_OP (padded_rhs_rows, rhs_vregs.dim (0 ) * layout_rhs.tiling ()[0 ]);
18061788
18071789 const VectorType i32_vreg_ty =
@@ -1823,27 +1805,64 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
18231805
18241806 // We can also extend this helper function with padding_top and padding_left
18251807 // based on the offsets in vregs.
1826- // TODO(b/341729764): Support mask subelements.
1808+ const Value i32_zeros_vreg = builder.create <arith::ConstantOp>(
1809+ op.getLoc (),
1810+ DenseElementsAttr::get (i32_vreg_ty, builder.getI32IntegerAttr (0 )));
1811+ const Value i32_max_vreg = builder.create <arith::ConstantOp>(
1812+ op.getLoc (), DenseElementsAttr::get (
1813+ i32_vreg_ty, builder.getI32IntegerAttr (0xffffffff )));
18271814 auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
18281815 int64_t padding_right) {
1829- const Value i32_zeros_vreg = builder.create <arith::ConstantOp>(
1830- op.getLoc (),
1831- DenseElementsAttr::get (i32_vreg_ty, builder.getI32IntegerAttr (0 )));
18321816 auto vreg_ty = cast<VectorType>(vregs.begin ()->getType ());
18331817 int packing = vreg_ty.getRank () > 2 ? vreg_ty.getShape ()[2 ] : 1 ;
18341818 // Mask out the bottom.
18351819 if (padding_bottom > 0 ) {
18361820 // The row size of LHS and RHS is not required to be a multiple of the
18371821 // native tiling, so padding_bottom may not be a multiple of packing. Bitcast
18381822 // to x32 vreg for masking; the sub-element remainder is masked separately.
1839- CHECK_EQ (padding_bottom % packing, 0 );
1840- padding_bottom /= packing;
1841- auto mask_bottom = getX32VmaskByPaddingEnd (0 , padding_bottom);
1823+ int sub_padding = padding_bottom % packing;
1824+ int x32_padding_bottom = padding_bottom / packing;
1825+ auto mask_bottom = getX32VmaskByPaddingEnd (0 , x32_padding_bottom);
1826+ // Create an int32 vreg which contains subelement masking and then
1827+ // logical_and with target vreg to mask out the unaligned paddings.
1828+ // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
1829+ // [8, 128], then the mask will be:
1830+ //
1831+ // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1832+ // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1833+ // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1834+ // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1835+ // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1836+ // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
1837+ // sublane 6: [0 , 0 , ..., 0 ]
1838+ // sublane 7: [0 , 0 , ..., 0 ]
1839+ //
1840+ // This way, masking sub-elements costs each target vreg only 1 op
1841+ // (logical_and) instead of 3 ops (unpacking + select
1842+ // + packing).
1843+ Value partial_sublane_mask = builder.create <arith::ConstantOp>(
1844+ op.getLoc (),
1845+ DenseElementsAttr::get (
1846+ i32_vreg_ty,
1847+ builder.getI32IntegerAttr (
1848+ 0xffffffff >>
1849+ (sub_padding * vreg_ty.getElementTypeBitWidth ()))));
1850+ // Insert 0xffffffff above the blended sublane.
1851+ Value sublane_mask = builder.create <arith::SelectOp>(
1852+ getX32VmaskByPaddingEnd (0 , x32_padding_bottom + 1 ), i32_max_vreg,
1853+ partial_sublane_mask);
1854+ // Insert 0 below the blended sublane.
1855+ sublane_mask = builder.create <arith::SelectOp>(mask_bottom, sublane_mask,
1856+ i32_zeros_vreg);
18421857 for (int64_t i = 0 ; i < vregs.dim (1 ); ++i) {
18431858 Value &vreg = vregs ({vregs.dim (0 ) - 1 , i});
18441859 Value i32_vreg = builder.create <tpu::BitcastVregOp>(i32_vreg_ty, vreg);
1845- i32_vreg = builder.create <arith::SelectOp>(mask_bottom, i32_vreg,
1846- i32_zeros_vreg);
1860+ if (sub_padding > 0 ) {
1861+ i32_vreg = builder.create <arith::AndIOp>(i32_vreg, sublane_mask);
1862+ } else {
1863+ i32_vreg = builder.create <arith::SelectOp>(mask_bottom, i32_vreg,
1864+ i32_zeros_vreg);
1865+ }
18471866 vreg = builder.create <tpu::BitcastVregOp>(vreg_ty, i32_vreg);
18481867 }
18491868 }
@@ -1929,8 +1948,9 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
19291948 lhs_zeros_vreg);
19301949 xla::Array<Value> target_rhs_vregs (
19311950 {target_rhs_row_vregs, target_rhs_col_vregs}, rhs_zeros_vreg);
1932- xla::Array<Value> target_acc_vregs ({acc_vregs.dim (0 ), target_acc_col_vregs},
1933- acc_zeros_vreg);
1951+ xla::Array<Value> target_acc_vregs (
1952+ {lhs_vregs.dim (0 ) * layout_lhs.packing (), target_acc_col_vregs},
1953+ acc_zeros_vreg);
19341954 target_lhs_vregs.UpdateSlice (lhs_vregs, {0 , 0 });
19351955 target_rhs_vregs.UpdateSlice (rhs_vregs, {0 , 0 });
19361956 target_acc_vregs.UpdateSlice (acc_vregs, {0 , 0 });
0 commit comments