Skip to content

Commit 57b2154

Browse files
WindQAQ authored and Google-ML-Automation committed
[Mosaic] NFC: Pull out vreg related functions to util.
These functions are related to vreg manipulation and are used in different rules.

PiperOrigin-RevId: 711484002
1 parent df36c29 commit 57b2154

File tree

5 files changed

+567
-184
lines changed

5 files changed

+567
-184
lines changed

jaxlib/mosaic/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cc_library(
4343
"dialect/tpu/tpu_dialect.cc",
4444
"dialect/tpu/tpu_ops.cc",
4545
"dialect/tpu/util.cc",
46+
"dialect/tpu/vreg_util.cc",
4647
":extension_srcs",
4748
] + glob([
4849
"dialect/tpu/transforms/*.cc",
@@ -51,6 +52,7 @@ cc_library(
5152
"dialect/tpu/layout.h",
5253
"dialect/tpu/tpu_dialect.h",
5354
"dialect/tpu/util.h",
55+
"dialect/tpu/vreg_util.h",
5456
] + glob([
5557
"dialect/tpu/transforms/*.h",
5658
]),
@@ -232,6 +234,19 @@ cc_library(
232234
alwayslink = True,
233235
)
234236

237+
cc_test(
238+
name = "vreg_util_test",
239+
srcs = ["dialect/tpu/vreg_util_test.cc"],
240+
deps = [
241+
":tpu_dialect",
242+
"//testing/base/public:gunit_main",
243+
"@llvm-project//mlir:ArithDialect",
244+
"@llvm-project//mlir:IR",
245+
"@llvm-project//mlir:Support",
246+
"@llvm-project//mlir:VectorDialect",
247+
],
248+
)
249+
235250
filegroup(
236251
name = "extension_srcs",
237252
srcs = [

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 36 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "mlir/IR/Attributes.h"
3030
#include "mlir/IR/Block.h"
3131
#include "mlir/IR/Builders.h"
32-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
3332
#include "mlir/IR/BuiltinAttributes.h"
3433
#include "mlir/IR/BuiltinTypeInterfaces.h"
3534
#include "mlir/IR/BuiltinTypes.h"
@@ -52,6 +51,7 @@
5251
#include "absl/status/status.h"
5352
#include "absl/types/span.h"
5453
#include "llvm/include/llvm/ADT/APInt.h"
54+
#include "llvm/include/llvm/Support/LogicalResult.h"
5555
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
5656
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
5757
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -64,6 +64,7 @@
6464
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
6565
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
6666
#include "jaxlib/mosaic/dialect/tpu/util.h"
67+
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
6768
#include "xla/array.h"
6869
#include "xla/layout.h"
6970
#include "xla/util.h"
@@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
275276
CHECK(data_it == data.end());
276277
}
277278

278-
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
279-
if (isa<FloatType>(ty)) {
280-
return TypedAttr(FloatAttr::get(ty, 0));
281-
}
282-
if (isa<IntegerType>(ty)) {
283-
return TypedAttr(IntegerAttr::get(ty, 0));
284-
}
285-
return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
286-
}
287-
288279
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
289280
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
290281
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
@@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
479470
return argument;
480471
}
481472

482-
VectorType getNativeVregOrVmaskTypeImpl(
483-
Type elem_ty, const int8_t bitwidth,
484-
const std::array<int64_t, 2> target_shape) {
485-
if (bitwidth == 32) {
486-
return VectorType::get(target_shape, elem_ty);
487-
}
488-
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
489-
elem_ty);
490-
}
491-
492-
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
493-
const std::array<int64_t, 2> target_shape) {
494-
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
495-
if (bitwidth == 1) {
496-
bitwidth = layout_bitwidth;
497-
} else {
498-
CHECK_EQ(bitwidth, layout_bitwidth);
499-
}
500-
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
501-
}
502-
503-
VectorType getNativeVregType(Type elem_ty,
504-
const std::array<int64_t, 2> target_shape) {
505-
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
506-
target_shape);
507-
}
508-
509473
// Masks all values outside of bounds.
510474
//
511475
// Arguments:
@@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
518482
// Returns:
519483
// An MLIR value of the same type as the value argument, with all entries
520484
// outside of bounds replaced by neutral.
521-
FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
485+
FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
522486
TypedValue<VectorType> value,
523487
const VRegDataBounds &bounds,
524488
const Attribute neutral) {
@@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
542506
value.getLoc(),
543507
VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
544508
}
545-
auto neutral_vec = builder.create<arith::ConstantOp>(
546-
value.getLoc(), native_vreg_ty,
547-
DenseElementsAttr::get(native_vreg_ty, neutral));
509+
Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
548510
return builder
549511
.create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
550512
.getResult();
@@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
18631825
TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
18641826
TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
18651827

1866-
const VectorType i32_vreg_ty =
1867-
getNativeVregType(builder.getI32Type(), ctx.target_shape);
1868-
auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
1869-
CHECK(dim == 0 || dim == 1);
1870-
CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
1871-
return cast<TypedValue<VectorType>>(
1872-
builder
1873-
.create<arith::CmpIOp>(
1874-
arith::CmpIPredicate::slt,
1875-
builder.create<tpu::IotaOp>(i32_vreg_ty,
1876-
builder.getI32IntegerAttr(dim)),
1877-
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
1878-
i32_vreg_ty, builder.getI32IntegerAttr(
1879-
ctx.target_shape[dim] - padding))))
1880-
.getResult());
1881-
};
1882-
1883-
// We can also extend this helper function with padding_top and padding_left
1884-
// based on the offsets in vregs.
1885-
const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
1886-
op.getLoc(),
1887-
DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
1888-
const Value i32_max_vreg = builder.create<arith::ConstantOp>(
1889-
op.getLoc(), DenseElementsAttr::get(
1890-
i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
1891-
auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
1892-
int64_t padding_right) {
1893-
auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
1894-
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
1895-
// Mask out the bottom.
1896-
if (padding_bottom > 0) {
1897-
// We have limited the row size of LHS and RHS need to be a multiple of
1898-
// native tiling at the beginning of this rule. Therefore, it is safe to
1899-
// bitcast to x32 vreg for masking.
1900-
int sub_padding = padding_bottom % packing;
1901-
int x32_padding_bottom = padding_bottom / packing;
1902-
auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
1903-
// Create an int32 vreg which contains subelement masking and then
1904-
// logical_and with target vreg to mask out the unaligned paddings.
1905-
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
1906-
// [8, 128], then the mask will be:
1907-
//
1908-
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1909-
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1910-
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1911-
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1912-
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
1913-
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
1914-
// sublane 6: [0 , 0 , ..., 0 ]
1915-
// sublane 7: [0 , 0 , ..., 0 ]
1916-
//
1917-
// Through this way, in order to mask sub-elements, each target vreg only
1918-
// needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
1919-
// + packing).
1920-
Value partial_sublane_mask = builder.create<arith::ConstantOp>(
1921-
op.getLoc(),
1922-
DenseElementsAttr::get(
1923-
i32_vreg_ty,
1924-
builder.getI32IntegerAttr(
1925-
0xffffffff >>
1926-
(sub_padding * vreg_ty.getElementTypeBitWidth()))));
1927-
// Insert 0xffffffff above the blended sublane.
1928-
Value sublane_mask = builder.create<arith::SelectOp>(
1929-
getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
1930-
partial_sublane_mask);
1931-
// Insert 0 below the blended sublane.
1932-
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
1933-
i32_zeros_vreg);
1934-
for (int64_t i = 0; i < vregs.dim(1); ++i) {
1935-
Value &vreg = vregs({vregs.dim(0) - 1, i});
1936-
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
1937-
if (sub_padding > 0) {
1938-
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
1939-
} else {
1940-
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
1941-
i32_zeros_vreg);
1942-
}
1943-
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
1944-
}
1945-
}
1946-
// Mask out the right.
1947-
if (padding_right > 0) {
1948-
auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
1949-
for (int64_t i = 0; i < vregs.dim(0); ++i) {
1950-
Value &vreg = vregs({i, vregs.dim(1) - 1});
1951-
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
1952-
i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
1953-
i32_zeros_vreg);
1954-
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
1955-
}
1956-
}
1957-
};
1958-
1959-
// Create a vreg filled with zeros.
1960-
auto getZerosVergLike =
1961-
[&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
1962-
const VectorType vreg_type = cast<VectorType>(vreg.getType());
1963-
FAILUREOR_ASSIGN_OR_RETURN(
1964-
const Attribute zero_attr,
1965-
getZeroIntOrFloatAttr(vreg_type.getElementType()));
1966-
return cast<TypedValue<VectorType>>(
1967-
builder
1968-
.create<arith::ConstantOp>(
1969-
op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
1970-
.getResult());
1971-
};
1972-
1973-
FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
1974-
getZerosVergLike(*lhs_vregs.begin()));
1975-
FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
1976-
getZerosVergLike(*rhs_vregs.begin()));
1977-
FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
1978-
getZerosVergLike(*acc_vregs.begin()));
1828+
auto lhs_zeros_vreg =
1829+
getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
1830+
auto rhs_zeros_vreg =
1831+
getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
1832+
auto acc_zeros_vreg =
1833+
getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
19791834

19801835
// Only mask out the paddings on contracting dim of LHS and RHS.
1981-
maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
1836+
RETURN_IF_FAILED(
1837+
maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
1838+
/*padding_bottom=*/0,
1839+
/*padding_right=*/padded_lhs_cols - lhs_shape[1]));
19821840
if (transpose_rhs) {
1983-
maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
1841+
RETURN_IF_FAILED(maskNativeTilingVregs(
1842+
builder, rhs_vregs, ctx.target_shape,
1843+
/*padding_bottom=*/0,
1844+
/*padding_right=*/padded_rhs_cols - rhs_shape[1]));
19841845
} else {
1985-
maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
1846+
RETURN_IF_FAILED(
1847+
maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
1848+
/*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
1849+
/*padding_right=*/0));
19861850
}
19871851

19881852
// At this point, all paddings on vregs are masked out. For now, we
@@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
28752739
native_vreg_ty,
28762740
/*dimension =*/builder.getI32IntegerAttr(1));
28772741
for (int64_t i = 0; i < num_tiles; ++i) {
2878-
auto offset = builder.create<arith::ConstantOp>(
2879-
native_vreg_ty,
2880-
DenseElementsAttr::get(
2881-
native_vreg_ty,
2882-
IntegerAttr::get(vty.getElementType(),
2883-
i * *(native_vreg_ty.getShape().end() - 1))));
2742+
Value offset = getFullVector(
2743+
builder, native_vreg_ty,
2744+
IntegerAttr::get(vty.getElementType(),
2745+
i * *(native_vreg_ty.getShape().end() - 1)));
28842746
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
28852747
}
28862748
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
29022764
native_vreg_ty,
29032765
/*dimension =*/builder.getI32IntegerAttr(0));
29042766
for (int64_t i = 0; i < num_tiles; ++i) {
2905-
auto offset = builder.create<arith::ConstantOp>(
2906-
native_vreg_ty,
2907-
DenseElementsAttr::get(
2908-
native_vreg_ty,
2909-
IntegerAttr::get(vty.getElementType(),
2910-
i * *(native_vreg_ty.getShape().end() - 2))));
2767+
Value offset = getFullVector(
2768+
builder, native_vreg_ty,
2769+
IntegerAttr::get(vty.getElementType(),
2770+
i * *(native_vreg_ty.getShape().end() - 2)));
29112771
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
29122772
}
29132773
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
29242784
SmallVector<Value> tiles;
29252785
tiles.reserve(vty.getDimSize(*dimension));
29262786
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
2927-
tiles.push_back(builder.create<arith::ConstantOp>(
2928-
native_vreg_ty,
2929-
DenseElementsAttr::get(native_vreg_ty,
2930-
IntegerAttr::get(vty.getElementType(), i))));
2787+
tiles.push_back(getFullVector(builder, native_vreg_ty,
2788+
IntegerAttr::get(vty.getElementType(), i)));
29312789
}
29322790
xla::Array<Value> out_tiles(tile_array_shape);
29332791
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
35163374
const int64_t offset = *offsets_in[1];
35173375
const int64_t lane_offset = offset % ctx.target_shape[1];
35183376
const int64_t tile_offset = offset / ctx.target_shape[1];
3519-
const auto idx_ty =
3520-
VectorType::get(ctx.target_shape, builder.getI32Type());
3521-
auto lane_offset_cst = builder.create<arith::ConstantOp>(
3522-
broadcast_op.getLoc(), idx_ty,
3523-
DenseElementsAttr::get(idx_ty,
3524-
builder.getI32IntegerAttr(lane_offset)));
3377+
Value lane_offset_cst = getFullVector(
3378+
builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
3379+
builder.getI32IntegerAttr(lane_offset));
35253380
DenseI32ArrayAttr sublane_pattern;
35263381
if (num_tiles != 1) {
35273382
SmallVector<int32_t> pattern;
@@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
35813436
getNativeVregType(src_i32.getType(), ctx.target_shape);
35823437
auto tile_i32 =
35833438
builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
3584-
auto zeros = builder.create<arith::ConstantOp>(
3585-
broadcast_op.getLoc(), tile_i32.getType(),
3586-
DenseElementsAttr::get(tile_i32.getType(),
3587-
builder.getI32IntegerAttr(0)));
3439+
Value zeros = getZerosVector(builder, tile_i32.getType());
35883440
auto tile =
35893441
builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
35903442
.getResult();

0 commit comments

Comments
 (0)