2929#include " mlir/IR/Attributes.h"
3030#include " mlir/IR/Block.h"
3131#include " mlir/IR/Builders.h"
32- #include " mlir/IR/BuiltinAttributeInterfaces.h"
3332#include " mlir/IR/BuiltinAttributes.h"
3433#include " mlir/IR/BuiltinTypeInterfaces.h"
3534#include " mlir/IR/BuiltinTypes.h"
5251#include " absl/status/status.h"
5352#include " absl/types/span.h"
5453#include " llvm/include/llvm/ADT/APInt.h"
54+ #include " llvm/include/llvm/Support/LogicalResult.h"
5555#include " mlir/include/mlir/Dialect/Arith/IR/Arith.h"
5656#include " mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
5757#include " mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
6464#include " jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
6565#include " jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
6666#include " jaxlib/mosaic/dialect/tpu/util.h"
67+ #include " jaxlib/mosaic/dialect/tpu/vreg_util.h"
6768#include " xla/array.h"
6869#include " xla/layout.h"
6970#include " xla/util.h"
@@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
275276 CHECK (data_it == data.end ());
276277}
277278
278- FailureOr<TypedAttr> getZeroIntOrFloatAttr (Type ty) {
279- if (isa<FloatType>(ty)) {
280- return TypedAttr (FloatAttr::get (ty, 0 ));
281- }
282- if (isa<IntegerType>(ty)) {
283- return TypedAttr (IntegerAttr::get (ty, 0 ));
284- }
285- return emitError (UnknownLoc::get (ty.getContext ()), " Not implemented: " ) << ty;
286- }
287-
288279FailureOr<int64_t > getIntConst (Value v, bool silent = false ) {
289280 if (auto constant_op = v.getDefiningOp <arith::ConstantOp>()) {
290281 if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue ())) {
@@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
479470 return argument;
480471}
481472
482- VectorType getNativeVregOrVmaskTypeImpl (
483- Type elem_ty, const int8_t bitwidth,
484- const std::array<int64_t , 2 > target_shape) {
485- if (bitwidth == 32 ) {
486- return VectorType::get (target_shape, elem_ty);
487- }
488- return VectorType::get ({target_shape[0 ], target_shape[1 ], 32 / bitwidth},
489- elem_ty);
490- }
491-
492- VectorType getNativeVregOrVmaskType (Type elem_ty, const int8_t layout_bitwidth,
493- const std::array<int64_t , 2 > target_shape) {
494- int8_t bitwidth = elem_ty.getIntOrFloatBitWidth ();
495- if (bitwidth == 1 ) {
496- bitwidth = layout_bitwidth;
497- } else {
498- CHECK_EQ (bitwidth, layout_bitwidth);
499- }
500- return getNativeVregOrVmaskTypeImpl (elem_ty, bitwidth, target_shape);
501- }
502-
503- VectorType getNativeVregType (Type elem_ty,
504- const std::array<int64_t , 2 > target_shape) {
505- return getNativeVregOrVmaskTypeImpl (elem_ty, elem_ty.getIntOrFloatBitWidth (),
506- target_shape);
507- }
508-
509473// Masks all values outside of bounds.
510474//
511475// Arguments:
@@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
518482// Returns:
519483// An MLIR value of the same type as the value argument, with all entries
520484// outside of bounds replaced by neutral.
521- FailureOr<Value> maskOOB (RewriteContext &ctx, OpBuilder &builder,
485+ FailureOr<Value> maskOOB (RewriteContext &ctx, ImplicitLocOpBuilder &builder,
522486 TypedValue<VectorType> value,
523487 const VRegDataBounds &bounds,
524488 const Attribute neutral) {
@@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
542506 value.getLoc (),
543507 VectorType::get (native_vreg_ty.getShape (), builder.getI1Type ()), mask);
544508 }
545- auto neutral_vec = builder.create <arith::ConstantOp>(
546- value.getLoc (), native_vreg_ty,
547- DenseElementsAttr::get (native_vreg_ty, neutral));
509+ Value neutral_vec = getFullVector (builder, native_vreg_ty, neutral);
548510 return builder
549511 .create <arith::SelectOp>(value.getLoc (), mask, value, neutral_vec)
550512 .getResult ();
@@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
18631825 TPU_ASSERT_EQ_OP (padded_lhs_rows, lhs_vregs.dim (0 ) * layout_lhs.tiling ()[0 ]);
18641826 TPU_ASSERT_EQ_OP (padded_rhs_rows, rhs_vregs.dim (0 ) * layout_rhs.tiling ()[0 ]);
18651827
1866- const VectorType i32_vreg_ty =
1867- getNativeVregType (builder.getI32Type (), ctx.target_shape );
1868- auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
1869- CHECK (dim == 0 || dim == 1 );
1870- CHECK (padding >= 0 && padding <= ctx.target_shape [dim]);
1871- return cast<TypedValue<VectorType>>(
1872- builder
1873- .create <arith::CmpIOp>(
1874- arith::CmpIPredicate::slt,
1875- builder.create <tpu::IotaOp>(i32_vreg_ty,
1876- builder.getI32IntegerAttr (dim)),
1877- builder.create <arith::ConstantOp>(DenseElementsAttr::get (
1878- i32_vreg_ty, builder.getI32IntegerAttr (
1879- ctx.target_shape [dim] - padding))))
1880- .getResult ());
1881- };
1882-
1883- // We can also extend this helper function with padding_top and padding_left
1884- // based on the offsets in vregs.
1885- const Value i32_zeros_vreg = builder.create <arith::ConstantOp>(
1886- op.getLoc (),
1887- DenseElementsAttr::get (i32_vreg_ty, builder.getI32IntegerAttr (0 )));
1888- const Value i32_max_vreg = builder.create <arith::ConstantOp>(
1889- op.getLoc (), DenseElementsAttr::get (
1890- i32_vreg_ty, builder.getI32IntegerAttr (0xffffffff )));
1891- auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
1892- int64_t padding_right) {
1893- auto vreg_ty = cast<VectorType>(vregs.begin ()->getType ());
1894- int packing = vreg_ty.getRank () > 2 ? vreg_ty.getShape ()[2 ] : 1 ;
1895- // Mask out the bottom.
1896- if (padding_bottom > 0 ) {
1897- // We have limited the row size of LHS and RHS need to be a multiple of
1898- // native tiling at the beginning of this rule. Therefore, it is safe to
1899- // bitcast to x32 vreg for masking.
1900- int sub_padding = padding_bottom % packing;
1901- int x32_padding_bottom = padding_bottom / packing;
1902- auto mask_bottom = getX32VmaskByPaddingEnd (0 , x32_padding_bottom);
1903- // Create an int32 vreg which contains subelement masking and then
1904- // logical_and with target vreg to mask out the unaligned paddings.
1905- // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
1906- // [8, 128], then the mask will be:
1907- //
1908- // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1909- // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1910- // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1911- // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1912- // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1913- // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
1914- // sublane 6: [0 , 0 , ..., 0 ]
1915- // sublane 7: [0 , 0 , ..., 0 ]
1916- //
1917- // Through this way, in order to mask sub-elements, each target vreg only
1918- // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
1919- // + packing).
1920- Value partial_sublane_mask = builder.create <arith::ConstantOp>(
1921- op.getLoc (),
1922- DenseElementsAttr::get (
1923- i32_vreg_ty,
1924- builder.getI32IntegerAttr (
1925- 0xffffffff >>
1926- (sub_padding * vreg_ty.getElementTypeBitWidth ()))));
1927- // Insert 0xffffffff above the blended sublane.
1928- Value sublane_mask = builder.create <arith::SelectOp>(
1929- getX32VmaskByPaddingEnd (0 , x32_padding_bottom + 1 ), i32_max_vreg,
1930- partial_sublane_mask);
1931- // Insert 0 below the blended sublane.
1932- sublane_mask = builder.create <arith::SelectOp>(mask_bottom, sublane_mask,
1933- i32_zeros_vreg);
1934- for (int64_t i = 0 ; i < vregs.dim (1 ); ++i) {
1935- Value &vreg = vregs ({vregs.dim (0 ) - 1 , i});
1936- Value i32_vreg = builder.create <tpu::BitcastVregOp>(i32_vreg_ty, vreg);
1937- if (sub_padding > 0 ) {
1938- i32_vreg = builder.create <arith::AndIOp>(i32_vreg, sublane_mask);
1939- } else {
1940- i32_vreg = builder.create <arith::SelectOp>(mask_bottom, i32_vreg,
1941- i32_zeros_vreg);
1942- }
1943- vreg = builder.create <tpu::BitcastVregOp>(vreg_ty, i32_vreg);
1944- }
1945- }
1946- // Mask out the right.
1947- if (padding_right > 0 ) {
1948- auto mask_right = getX32VmaskByPaddingEnd (1 , padding_right);
1949- for (int64_t i = 0 ; i < vregs.dim (0 ); ++i) {
1950- Value &vreg = vregs ({i, vregs.dim (1 ) - 1 });
1951- Value i32_vreg = builder.create <tpu::BitcastVregOp>(i32_vreg_ty, vreg);
1952- i32_vreg = builder.create <arith::SelectOp>(mask_right, i32_vreg,
1953- i32_zeros_vreg);
1954- vreg = builder.create <tpu::BitcastVregOp>(vreg_ty, i32_vreg);
1955- }
1956- }
1957- };
1958-
1959- // Create a vreg filled with zeros.
1960- auto getZerosVergLike =
1961- [&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
1962- const VectorType vreg_type = cast<VectorType>(vreg.getType ());
1963- FAILUREOR_ASSIGN_OR_RETURN (
1964- const Attribute zero_attr,
1965- getZeroIntOrFloatAttr (vreg_type.getElementType ()));
1966- return cast<TypedValue<VectorType>>(
1967- builder
1968- .create <arith::ConstantOp>(
1969- op.getLoc (), DenseElementsAttr::get (vreg_type, zero_attr))
1970- .getResult ());
1971- };
1972-
1973- FAILUREOR_ASSIGN_OR_RETURN (auto lhs_zeros_vreg,
1974- getZerosVergLike (*lhs_vregs.begin ()));
1975- FAILUREOR_ASSIGN_OR_RETURN (auto rhs_zeros_vreg,
1976- getZerosVergLike (*rhs_vregs.begin ()));
1977- FAILUREOR_ASSIGN_OR_RETURN (auto acc_zeros_vreg,
1978- getZerosVergLike (*acc_vregs.begin ()));
1828+ auto lhs_zeros_vreg =
1829+ getZerosVector (builder, cast<VectorType>(lhs_vregs.begin ()->getType ()));
1830+ auto rhs_zeros_vreg =
1831+ getZerosVector (builder, cast<VectorType>(rhs_vregs.begin ()->getType ()));
1832+ auto acc_zeros_vreg =
1833+ getZerosVector (builder, cast<VectorType>(acc_vregs.begin ()->getType ()));
19791834
19801835 // Only mask out the paddings on contracting dim of LHS and RHS.
1981- maskVregs (lhs_vregs, 0 , padded_lhs_cols - lhs_shape[1 ]);
1836+ RETURN_IF_FAILED (
1837+ maskNativeTilingVregs (builder, lhs_vregs, ctx.target_shape ,
1838+ /* padding_bottom=*/ 0 ,
1839+ /* padding_right=*/ padded_lhs_cols - lhs_shape[1 ]));
19821840 if (transpose_rhs) {
1983- maskVregs (rhs_vregs, 0 , padded_rhs_cols - rhs_shape[1 ]);
1841+ RETURN_IF_FAILED (maskNativeTilingVregs (
1842+ builder, rhs_vregs, ctx.target_shape ,
1843+ /* padding_bottom=*/ 0 ,
1844+ /* padding_right=*/ padded_rhs_cols - rhs_shape[1 ]));
19841845 } else {
1985- maskVregs (rhs_vregs, padded_rhs_rows - rhs_shape[0 ], 0 );
1846+ RETURN_IF_FAILED (
1847+ maskNativeTilingVregs (builder, rhs_vregs, ctx.target_shape ,
1848+ /* padding_bottom=*/ padded_rhs_rows - rhs_shape[0 ],
1849+ /* padding_right=*/ 0 ));
19861850 }
19871851
19881852 // At this point, all paddings on vregs are masked out. For now, we
@@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
28752739 native_vreg_ty,
28762740 /* dimension =*/ builder.getI32IntegerAttr (1 ));
28772741 for (int64_t i = 0 ; i < num_tiles; ++i) {
2878- auto offset = builder.create <arith::ConstantOp>(
2879- native_vreg_ty,
2880- DenseElementsAttr::get (
2881- native_vreg_ty,
2882- IntegerAttr::get (vty.getElementType (),
2883- i * *(native_vreg_ty.getShape ().end () - 1 ))));
2742+ Value offset = getFullVector (
2743+ builder, native_vreg_ty,
2744+ IntegerAttr::get (vty.getElementType (),
2745+ i * *(native_vreg_ty.getShape ().end () - 1 )));
28842746 tiles[i] = builder.create <arith::AddIOp>(vreg_iota, offset);
28852747 }
28862748 xla::Array<Value> broadcasted_tiles (tile_array_shape);
@@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
29022764 native_vreg_ty,
29032765 /* dimension =*/ builder.getI32IntegerAttr (0 ));
29042766 for (int64_t i = 0 ; i < num_tiles; ++i) {
2905- auto offset = builder.create <arith::ConstantOp>(
2906- native_vreg_ty,
2907- DenseElementsAttr::get (
2908- native_vreg_ty,
2909- IntegerAttr::get (vty.getElementType (),
2910- i * *(native_vreg_ty.getShape ().end () - 2 ))));
2767+ Value offset = getFullVector (
2768+ builder, native_vreg_ty,
2769+ IntegerAttr::get (vty.getElementType (),
2770+ i * *(native_vreg_ty.getShape ().end () - 2 )));
29112771 tiles[i] = builder.create <arith::AddIOp>(vreg_iota, offset);
29122772 }
29132773 xla::Array<Value> broadcasted_tiles (tile_array_shape);
@@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
29242784 SmallVector<Value> tiles;
29252785 tiles.reserve (vty.getDimSize (*dimension));
29262786 for (int64_t i = 0 ; i < vty.getDimSize (*dimension); ++i) {
2927- tiles.push_back (builder.create <arith::ConstantOp>(
2928- native_vreg_ty,
2929- DenseElementsAttr::get (native_vreg_ty,
2930- IntegerAttr::get (vty.getElementType (), i))));
2787+ tiles.push_back (getFullVector (builder, native_vreg_ty,
2788+ IntegerAttr::get (vty.getElementType (), i)));
29312789 }
29322790 xla::Array<Value> out_tiles (tile_array_shape);
29332791 out_tiles.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
@@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
35163374 const int64_t offset = *offsets_in[1 ];
35173375 const int64_t lane_offset = offset % ctx.target_shape [1 ];
35183376 const int64_t tile_offset = offset / ctx.target_shape [1 ];
3519- const auto idx_ty =
3520- VectorType::get (ctx.target_shape , builder.getI32Type ());
3521- auto lane_offset_cst = builder.create <arith::ConstantOp>(
3522- broadcast_op.getLoc (), idx_ty,
3523- DenseElementsAttr::get (idx_ty,
3524- builder.getI32IntegerAttr (lane_offset)));
3377+ Value lane_offset_cst = getFullVector (
3378+ builder, getNativeVregType (builder.getI32Type (), ctx.target_shape ),
3379+ builder.getI32IntegerAttr (lane_offset));
35253380 DenseI32ArrayAttr sublane_pattern;
35263381 if (num_tiles != 1 ) {
35273382 SmallVector<int32_t > pattern;
@@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
35813436 getNativeVregType (src_i32.getType (), ctx.target_shape );
35823437 auto tile_i32 =
35833438 builder.create <vector::BroadcastOp>(native_vreg_ty, src_i32);
3584- auto zeros = builder.create <arith::ConstantOp>(
3585- broadcast_op.getLoc (), tile_i32.getType (),
3586- DenseElementsAttr::get (tile_i32.getType (),
3587- builder.getI32IntegerAttr (0 )));
3439+ Value zeros = getZerosVector (builder, tile_i32.getType ());
35883440 auto tile =
35893441 builder.create <arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
35903442 .getResult ();
0 commit comments