Skip to content

Commit ab01f17

Browse files
authored
[Blocking] Emulate the unpack and pack on vectors using insert_strided_slice and extract_strided_slice (#1023)
1 parent 71e01d3 commit ab01f17

File tree

15 files changed

+705
-1049
lines changed

15 files changed

+705
-1049
lines changed

include/imex/Utils/XeCommon.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,6 @@ applyVnniTransform(mlir::OpBuilder &builder,
8484
// 16, 32, and 64 are only available if simdLanes == 1.
8585
llvm::SmallVector<int> getSupportedChunkSizes(int simdlanes);
8686

87-
using PackFuncTy = std::function<mlir::TypedValue<mlir::VectorType>(
88-
mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>;
89-
90-
// A wrapper function to merge small vectors into a big one. It takes a
91-
// range of mlir::Value objects with mlir::VectorType, and merge them
92-
// into a big vector using the provided transformation function.
93-
mlir::Value packVectorsWith(mlir::ValueRange ins, PackFuncTy op,
94-
mlir::Location loc, mlir::OpBuilder &builder);
95-
9687
// Combine vectors vertically while keeping the logical data layout.
9788
// As an example, given two vectors (2x4xf16) p and q, it will merge
9889
// them in to a 4x4xf16 vector.
@@ -105,18 +96,6 @@ mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
10596
mlir::Location loc,
10697
mlir::OpBuilder &builder);
10798

108-
// merge vectors horizontally while keep the logical data layout.
109-
// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12
110-
// 5 6 7 8 13 14 15 5 6 7 8 13 14 15
111-
// since there is no direct op in mlir exists, we will
112-
// using ShapeCast and Shuffle to mimic it. It comes with
113-
// cost of complex shuffle masks. the mask for the above one
114-
// will be like this: 0 1 2 3 8 9 10
115-
// 4 5 6 7 11 12 13
116-
mlir::TypedValue<mlir::VectorType> concat(mlir::Value lhs, mlir::Value rhs,
117-
mlir::Location loc,
118-
mlir::OpBuilder &builder);
119-
12099
// It checks each GPUFuncOp in the module to see
121100
// whether they have arguments and outputs with
122101
// xetile.TileType. They are currently not supported yet.

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 95 additions & 268 deletions
Large diffs are not rendered by default.

lib/Transforms/VectorLinearize.cpp

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,17 @@ struct VectorExtractStridedSliceConversion final
322322
}
323323
};
324324

325+
// clang-format off
326+
// linearize InsertStridedSliceOp by extracting rows from the source vector
327+
// using extract_strided_slice and inserting them into the destination vector
328+
// using insert_strided_slice. For example:
329+
// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
330+
// will be lowered into (both s and d are linearized to 1D):
331+
// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
332+
// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
333+
// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
334+
// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
335+
// clang-format on
325336
struct VectorInsertStridedSliceConversion final
326337
: public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
327338
using mlir::OpConversionPattern<
@@ -330,31 +341,47 @@ struct VectorInsertStridedSliceConversion final
330341
mlir::LogicalResult
331342
matchAndRewrite(mlir::vector::InsertStridedSliceOp op, OpAdaptor adaptor,
332343
mlir::ConversionPatternRewriter &rewriter) const override {
344+
auto loc = op.getLoc();
333345
auto srcTy = op.getSourceVectorType();
334-
auto destTy = op.getDestVectorType();
346+
auto dstTy = op.getDestVectorType();
335347

336348
if (op.hasNonUnitStrides()) {
337349
return rewriter.notifyMatchFailure(
338-
op, "InsertStridedSliceOp only supports unit strides.");
350+
op, "InsertStridedSliceOp linearization only supports unit strides.");
339351
}
340352

341-
if (llvm::any_of(srcTy.getShape().drop_back(),
342-
[](int64_t dim) { return dim != 1; })) {
343-
return rewriter.notifyMatchFailure(op,
344-
"Only supports vectors with leading "
345-
"dims (except the last dim) as 1s.");
353+
if (srcTy.getRank() != 2) {
354+
return rewriter.notifyMatchFailure(
355+
op, "InsertStridedSliceOp linearization only supports 2D source.");
346356
}
347357

348-
auto strides = destTy.getShape().drop_front().vec();
349-
strides.push_back(1);
358+
if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
359+
return rewriter.notifyMatchFailure(
360+
op, "InsertStridedSliceOp linearization only supports static shapes.");
361+
}
362+
363+
auto dstShape = dstTy.getShape();
364+
auto dstStrides = dstShape.drop_front().vec();
365+
dstStrides.push_back(1);
350366
int64_t linearizedOffset = 0;
351-
for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), strides)) {
367+
for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
352368
linearizedOffset += mlir::getConstantIntValue(off).value() * stride;
353369
}
354370

355-
rewriter.replaceOpWithNewOp<mlir::vector::InsertStridedSliceOp>(
356-
op, adaptor.getSource(), adaptor.getDest(), linearizedOffset, 1);
371+
// extracts a row from source, and insert it into the destination
372+
auto srcShape = srcTy.getShape();
373+
mlir::Value dstValue = adaptor.getDest();
374+
for (auto i = 0; i < srcShape[0]; i++) {
375+
auto srcOffset = i * srcShape[1];
376+
auto value = rewriter.create<mlir::vector::ExtractStridedSliceOp>(
377+
loc, adaptor.getSource(), srcOffset, srcShape[1], 1);
378+
379+
auto dstOffset = linearizedOffset + i * dstShape.back();
380+
dstValue = rewriter.create<mlir::vector::InsertStridedSliceOp>(
381+
loc, value, dstValue, dstOffset, 1);
382+
}
357383

384+
rewriter.replaceOp(op, dstValue);
358385
return mlir::success();
359386
}
360387
};
@@ -672,9 +699,9 @@ struct VectorLinearizePass final
672699
target.addDynamicallyLegalOp<mlir::vector::InsertStridedSliceOp>(
673700
[&](mlir::vector::InsertStridedSliceOp op) {
674701
auto srcTy = op.getSourceVectorType();
675-
if (!op.hasNonUnitStrides() && srcTy.getRank() != 1 &&
676-
llvm::all_of(srcTy.getShape().drop_back(),
677-
[](int64_t dim) { return dim == 1; }))
702+
auto dstTy = op.getDestVectorType();
703+
if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
704+
srcTy.hasStaticShape() && dstTy.hasStaticShape())
678705
return false;
679706
return true;
680707
});

lib/Utils/XeCommon.cpp

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -318,94 +318,6 @@ mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
318318
return op;
319319
}
320320

321-
// generate linearized shuffle mask for concat.
322-
static llvm::SmallVector<int64_t>
323-
getShuffleMask(llvm::ArrayRef<int64_t> shape1, llvm::ArrayRef<int64_t> shape2) {
324-
assert(shape1.size() == shape2.size() && shape1.size() <= 2 &&
325-
"only 1D/2D shape are supported.");
326-
assert(shape1.drop_back() == shape2.drop_back() &&
327-
"the row dim of the shapes should match.");
328-
int64_t size1 = std::accumulate(shape1.begin(), shape1.end(), 1,
329-
std::multiplies<int64_t>());
330-
int64_t size2 = std::accumulate(shape2.begin(), shape2.end(), 1,
331-
std::multiplies<int64_t>());
332-
llvm::SmallVector<int64_t> mask(size1 + size2);
333-
auto rows = shape1.size() == 1 ? 1 : shape1[0];
334-
auto cols1 = shape1.size() == 1 ? shape1[0] : shape1[1];
335-
auto cols2 = shape2.size() == 1 ? shape2[0] : shape2[1];
336-
for (int64_t i = 0; i < rows; i++) {
337-
int64_t s = i * (cols1 + cols2);
338-
int64_t m = s + cols1;
339-
int64_t e = m + cols2;
340-
int64_t v1 = i * cols1;
341-
int64_t v2 = size1 + i * cols2;
342-
std::iota(mask.begin() + s, mask.begin() + m, v1);
343-
std::iota(mask.begin() + m, mask.begin() + e, v2);
344-
}
345-
return mask;
346-
}
347-
348-
mlir::TypedValue<mlir::VectorType> concat(mlir::Value lhs, mlir::Value rhs,
349-
mlir::Location loc,
350-
mlir::OpBuilder &builder) {
351-
auto lhsTy = llvm::cast<mlir::VectorType>(lhs.getType());
352-
auto rhsTy = llvm::cast<mlir::VectorType>(rhs.getType());
353-
354-
assert(lhsTy.getShape()[0] == lhsTy.getShape()[0] &&
355-
"Operands of concat() do not have the same number of rows.");
356-
assert(lhsTy.getRank() <= 2 && rhsTy.getRank() == lhsTy.getRank() &&
357-
"Currently concat only works on 1D/2D vector.");
358-
359-
auto elemTy = lhsTy.getElementType();
360-
auto leftSize = lhsTy.getNumElements();
361-
auto leftShape = lhsTy.getShape();
362-
auto leftFlatTy = mlir::VectorType::get({lhsTy.getNumElements()}, elemTy);
363-
364-
auto rightSize = rhsTy.getNumElements();
365-
auto rightShape = rhsTy.getShape();
366-
auto rightFlatTy = mlir::VectorType::get({rhsTy.getNumElements()}, elemTy);
367-
368-
auto newShape = lhsTy.getRank() == 1
369-
? llvm::SmallVector<int64_t>({leftSize + rightSize})
370-
: llvm::SmallVector<int64_t>(
371-
{leftShape[0], leftShape[1] + rightShape[1]});
372-
auto castLeft =
373-
builder.create<mlir::vector::ShapeCastOp>(loc, leftFlatTy, lhs);
374-
auto castRight =
375-
builder.create<mlir::vector::ShapeCastOp>(loc, rightFlatTy, rhs);
376-
auto mask = getShuffleMask(leftShape, rightShape);
377-
auto shuffleOp =
378-
builder.create<mlir::vector::ShuffleOp>(loc, castLeft, castRight, mask);
379-
auto targetTy = mlir::VectorType::get(newShape, elemTy);
380-
auto newOp =
381-
builder.create<mlir::vector::ShapeCastOp>(loc, targetTy, shuffleOp);
382-
return newOp;
383-
}
384-
385-
// A wrapper function to merge small vectors into a big one. It takes a
386-
// range of mlir::Value objects with mlir::VectorType, and merge them
387-
// into a big vector using the provided transformation function.
388-
mlir::Value packVectorsWith(mlir::ValueRange ins, PackFuncTy op,
389-
mlir::Location loc, mlir::OpBuilder &builder) {
390-
llvm::SmallVector<mlir::Value> shuffleOps(ins.begin(), ins.end());
391-
while (shuffleOps.size() > 1) {
392-
auto curr = shuffleOps;
393-
shuffleOps.clear();
394-
size_t currPairStartIdx{0};
395-
while (currPairStartIdx < curr.size() - 1) {
396-
size_t leftIdx{currPairStartIdx++};
397-
size_t rightIdx{currPairStartIdx++};
398-
auto newOp = op(curr[leftIdx], curr[rightIdx], loc, builder);
399-
shuffleOps.push_back(newOp);
400-
}
401-
if (currPairStartIdx < curr.size()) {
402-
assert(currPairStartIdx == curr.size() - 1);
403-
shuffleOps.push_back(curr[curr.size() - 1]);
404-
}
405-
}
406-
return shuffleOps[0];
407-
}
408-
409321
/// Checks if the given `type` is a 1-D vector type that requires VectorAnyINTEL
410322
/// capability. In other words, the vector size is not supported by SPIR-V.
411323
/// SPIR-V only supports 2, 3, 4, 8, 16 elements (8 and 16 with Vector16

test/Conversion/XeTileToXeGPU/gemm_preop.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,15 @@ module attributes {gpu.container_module} {
6262
%27 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
6363
%28 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
6464
xegpu.compile_hint
65-
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [{{.*}}], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16>
65+
//CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [{{.*}}], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16>
6666
//CHECK-COUNT-8: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<8x16xf16>
6767
%29 = arith.addf %27, %27 : vector<32x32xf16>
6868
xegpu.compile_hint
6969
%30 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16>
7070
%31 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16>
7171
xegpu.compile_hint
7272

73+
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [{{.*}}], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16>
7374
// CHECK-COUNT-16: {{.*}} = xegpu.dpas {{.*}}, {{.*}}, {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
7475
%32 = xetile.tile_mma %29, %28, %arg7 : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
7576
xegpu.compile_hint

test/Conversion/XeTileToXeGPU/sg_softmax.mlir

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ gpu.module @test_kernel {
1818
//CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
1919
%2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16>
2020

21-
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
22-
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
21+
//CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
2322
//CHECK-COUNT-8: {{.*}} = math.exp %{{.*}} : vector<8x32xf16>
2423
%3 = math.exp %2: vector<32x64xf16>
2524
//CHECK-COUNT-62: arith.addf {{.*}}, {{.*}} : vector<1x32xf16>
@@ -42,8 +41,7 @@ gpu.module @test_kernel {
4241
//CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
4342
//CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
4443
%2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16>
45-
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
46-
//CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
44+
//CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16>
4745
//CHECK-COUNT-8: {{.*}} = math.exp %{{.*}} : vector<8x32xf16>
4846
%3 = math.exp %2: vector<32x64xf16>
4947
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x32xf16>
@@ -203,22 +201,13 @@ gpu.module @test_kernel {
203201
//CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62] : vector<32xf16>, vector<32xf16>
204202
//CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16>
205203
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<32xf16>
206-
//CHECK-COUNT-32: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16>
204+
//CHECK-COUNT-32: {{.*}} = vector.extractelement {{.*}}[{{.*}} : index] : vector<32xf16>
207205
//CHECK-COUNT-32: {{.*}} = vector.splat {{.*}} : vector<1x32xf16>
208206
%4 = xetile.reduction <add>, %3 [1]: vector<32x64xf16> -> vector<32x1xf16>
209207

210-
//CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16>
211-
//CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16>
212-
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16>
213-
//CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16>
214-
//CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16>
215-
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16>
216-
//CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16>
217-
//CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16>
218-
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16>
219-
//CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16>
220-
//CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16>
221-
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16>
208+
//CHECK-COUNT-64: %{{.*}} = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<1x32xf16> into vector<32x64xf16>
209+
//CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x64xf16> to vector<8x32xf16>
210+
222211
%5 = xetile.broadcast %4 [1]: vector<32x1xf16> -> vector<32x64xf16>
223212
// CHECK-COUNT-8: {{.*}} = arith.divf {{.*}}, {{.*}} : vector<8x32xf16>
224213
%6 = arith.divf %3, %5: vector<32x64xf16>

test/Conversion/XeTileToXeGPU/sg_store_tile.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
2-
// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s
2+
// RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s
33

44
gpu.module @test_kernel {
55
//CHECK: gpu.func @sg_tiled_store(%[[arg0:.*]]: memref<1024x1024xf32>) {

0 commit comments

Comments
 (0)