Skip to content

Commit 8835e8d

Browse files
authored
Blocking support for vector.transpose (#743)
Transpose blocking
1 parent 85c4be7 commit 8835e8d

File tree

5 files changed

+197
-9
lines changed

5 files changed

+197
-9
lines changed

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,48 @@ struct SgUpdateTileOffsetOpPattern
718718
}
719719
};
720720

721+
struct SgTransposeOpPattern
722+
: public SgXeTileToXeGPUConversion<mlir::vector::TransposeOp> {
723+
using SgXeTileToXeGPUConversion::SgXeTileToXeGPUConversion;
724+
725+
mlir::LogicalResult
726+
matchAndRewrite(mlir::vector::TransposeOp op, OpAdaptor adaptor,
727+
XeGPUOneToNPatterRewriter &rewriter) const override {
728+
auto resType = op.getResult().getType();
729+
if (resType.getRank() != 4)
730+
return ((mlir::PatternRewriter &)rewriter)
731+
.notifyMatchFailure(op, "Expected a 4D vector");
732+
733+
auto srcVectors = adaptor.getVector();
734+
auto shape = resType.getShape();
735+
if (shape[0] * shape[1] != static_cast<int64_t>(srcVectors.size()))
736+
return ((mlir::PatternRewriter &)rewriter)
737+
.notifyMatchFailure(op, "Invalid shape");
738+
739+
auto permutation = op.getPermutation();
740+
auto outerPerm = permutation.take_front(2);
741+
int64_t innerPerm[2] = {permutation[2] - 2, permutation[3] - 2};
742+
743+
auto newResType =
744+
mlir::VectorType::get(shape.take_back(2), resType.getElementType());
745+
746+
mlir::Location loc = op.getLoc();
747+
llvm::SmallVector<mlir::Value> results;
748+
for (auto i : llvm::seq<size_t>(0, shape[0])) {
749+
for (auto j : llvm::seq<size_t>(0, shape[1])) {
750+
size_t ij[2] = {i, j};
751+
auto idx = ij[outerPerm[1]] + shape[outerPerm[1]] * ij[outerPerm[0]];
752+
mlir::Value arg = srcVectors[idx];
753+
mlir::Value res = rewriter.create<mlir::vector::TransposeOp>(
754+
loc, newResType, arg, innerPerm);
755+
results.emplace_back(res);
756+
}
757+
}
758+
rewriter.replaceOp(op, results);
759+
return mlir::success();
760+
}
761+
};
762+
721763
bool isLegalElementWiseOp(mlir::Operation *op) {
722764
auto res = op->getResult(0);
723765
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
@@ -801,8 +843,8 @@ void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter,
801843
patterns.insert<SgInitTileOpPattern, SgPrefetchTileOpPattern,
802844
SgTileUnpackOpPattern, SgTilePackOpPattern,
803845
SgLoadTileOpPattern, SgStoreTileOpPattern, SgTileMMAOpPattern,
804-
SgUpdateTileOffsetOpPattern>(patterns.getContext(), converter,
805-
analysis);
846+
SgUpdateTileOffsetOpPattern, SgTransposeOpPattern>(
847+
patterns.getContext(), converter, analysis);
806848
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
807849
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
808850
ElementWiseOpPattern<mlir::math::SinOp, 1>,

lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
125125
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
126126
addDynamicallyLegalOp<mlir::math::TanhOp>(
127127
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
128+
129+
addDynamicallyLegalOp<mlir::vector::TransposeOp>(
130+
[](mlir::vector::TransposeOp op) {
131+
return op.getResult().getType().getRank() == 2;
132+
});
128133
}
129134

130135
private:

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ populateXeTileBlockAligningPatterns(imex::XeTypeConverter &converter,
5858
mlir::RewritePatternSet &patterns,
5959
PropagateAnalysis &analysis);
6060

61-
enum OpType { Prefetch, Load, Store, Elementwise };
61+
enum OpType { Prefetch, Load, Store, Elementwise, Transpose };
6262

6363
// Find the maximum divisible number between minHeight/Width and maxHeight/Width
6464
// and use that as the inner block sizes.
@@ -170,8 +170,26 @@ getInnerBlockSizes(mlir::Operation *operation, mlir::Type elemTy, int height,
170170
// TODO: get from uArch?
171171
int64_t subgroupSize = 16;
172172

173-
return {1, subgroupSize};
173+
maxHeight = 1;
174+
minHeight = 1;
175+
maxWidth = subgroupSize;
176+
minWidth = 1;
177+
178+
return imex::getInnerBlockHeightWidth(maxHeight, maxWidth, minHeight,
179+
minWidth, height, width);
180+
}
181+
182+
if (op == OpType::Transpose) {
183+
// TODO: get from uArch?
184+
maxHeight = 16;
185+
minHeight = 1;
186+
maxWidth = 16;
187+
minWidth = 1;
188+
189+
return imex::getInnerBlockHeightWidth(maxHeight, maxWidth, minHeight,
190+
minWidth, height, width);
174191
}
192+
175193
llvm_unreachable("Unsupported.");
176194
return {};
177195
}
@@ -368,6 +386,70 @@ struct VectorizableOpPattern
368386
}
369387
};
370388

389+
struct TransposeOpPattern
390+
: public XeTileConversion<mlir::vector::TransposeOp, TileUsageAnalysis> {
391+
392+
using XeTileConversion::XeTileConversion;
393+
394+
TransposeOpPattern(mlir::MLIRContext *context,
395+
imex::XeTypeConverter &converter,
396+
TileUsageAnalysis &analysis,
397+
std::shared_ptr<XeuArchInterface> ptruArch)
398+
: XeTileConversion(context, converter, analysis) {
399+
this->uArchInterface = ptruArch;
400+
}
401+
402+
std::shared_ptr<XeuArchInterface> uArchInterface = nullptr;
403+
404+
mlir::LogicalResult
405+
matchAndRewrite(mlir::vector::TransposeOp op, OpAdaptor adaptor,
406+
mlir::PatternRewriter &rewriter) const override {
407+
auto res = op.getResult();
408+
auto resType = mlir::cast<mlir::VectorType>(res.getType());
409+
if (resType.getRank() != 2)
410+
return rewriter.notifyMatchFailure(op, "type is not 2D vector");
411+
412+
auto permutation = op.getPermutation();
413+
if (permutation != mlir::ArrayRef<int64_t>({1, 0}))
414+
return rewriter.notifyMatchFailure(op, "Unsupported permutation");
415+
416+
auto shape = resType.getShape();
417+
auto blocks = getInnerBlockSizes<Transpose>(
418+
op, resType.getElementType(), shape[0], shape[1], this->uArchInterface);
419+
420+
if (blocks.size() != 2)
421+
return rewriter.notifyMatchFailure(op, "Invalid inner block sizes");
422+
423+
int64_t inBlocks[2] = {blocks[1], blocks[0]};
424+
425+
auto newSrcTy = mlir::VectorType::get(
426+
{shape[1] / blocks[1], shape[0] / blocks[0], blocks[1], blocks[0]},
427+
resType.getElementType());
428+
429+
auto newDstTy = mlir::VectorType::get(
430+
{shape[0] / blocks[0], shape[1] / blocks[1], blocks[0], blocks[1]},
431+
resType.getElementType());
432+
433+
mlir::Value arg = adaptor.getVector();
434+
Location loc = op->getLoc();
435+
mlir::Value pack = rewriter.create<xetile::TilePackOp>(
436+
loc, newSrcTy, arg,
437+
mlir::DenseI64ArrayAttr::get(getContext(), inBlocks));
438+
439+
int64_t newPermutation[4] = {1, 0, 3, 2};
440+
mlir::Value transpose = rewriter.create<mlir::vector::TransposeOp>(
441+
loc, newDstTy, pack, newPermutation);
442+
443+
mlir::Value unpack = rewriter.create<xetile::TileUnpackOp>(
444+
loc, resType, transpose,
445+
mlir::DenseI64ArrayAttr::get(getContext(), blocks));
446+
447+
rewriter.replaceOp(op, unpack);
448+
449+
return mlir::success();
450+
}
451+
};
452+
371453
struct VectorMultiDimReductionOpPattern
372454
: public XeTileConversion<mlir::vector::MultiDimReductionOp,
373455
TileUsageAnalysis> {
@@ -873,11 +955,12 @@ struct UpdateTileOffsetOpPattern
873955
void populateXeTileBlockingPatterns(
874956
imex::XeTypeConverter &converter, mlir::RewritePatternSet &patterns,
875957
TileUsageAnalysis &analysis, std::shared_ptr<XeuArchInterface> ptruArch) {
876-
patterns.insert<ArithConstantOpPattern, VectorizableOpPattern,
877-
SCFForOpPattern, SCFYieldOpPattern, InitTileOpPattern,
878-
LoadTileOpPattern, StoreTileOpPattern, TileMMAOpPattern,
879-
UpdateTileOffsetOpPattern, VectorMultiDimReductionOpPattern>(
880-
patterns.getContext(), converter, analysis, ptruArch);
958+
patterns
959+
.insert<ArithConstantOpPattern, VectorizableOpPattern, SCFForOpPattern,
960+
SCFYieldOpPattern, InitTileOpPattern, LoadTileOpPattern,
961+
StoreTileOpPattern, TileMMAOpPattern, UpdateTileOffsetOpPattern,
962+
TransposeOpPattern, VectorMultiDimReductionOpPattern>(
963+
patterns.getContext(), converter, analysis, ptruArch);
881964
}
882965

883966
// Lowers XeTile to blocked layout with high-dim vector

test/Conversion/XeTileToXeGPU/test_blocking.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,37 @@ func.func @test_blocking_elementwise(%a: vector<64x64xf16>, %b: vector<64x64xf16
2424
}
2525

2626
}
27+
28+
// -----
29+
30+
gpu.module @test_kernel {
31+
32+
// CHECK-LABEL: test_blocking_transpose
33+
// CHECK-SAME: (%[[SRC:.*]]: vector<64x32xf16>)
34+
// CHECK: %[[PACK:.*]] = xetile.tile_pack %[[SRC]] { inner_blocks = [16, 16] } : vector<64x32xf16> -> vector<4x2x16x16xf16>
35+
// CHECK: %[[T:.*]] = vector.transpose %[[PACK]], [1, 0, 3, 2] : vector<4x2x16x16xf16> to vector<2x4x16x16xf16>
36+
// CHECK: %[[UNPACK:.*]] = xetile.tile_unpack %[[T]] { inner_blocks = [16, 16] } : vector<2x4x16x16xf16> -> vector<32x64xf16>
37+
// CHECK: return %[[UNPACK]] : vector<32x64xf16>
38+
func.func @test_blocking_transpose(%a: vector<64x32xf16>) -> vector<32x64xf16> {
39+
%0 = vector.transpose %a, [1, 0]: vector<64x32xf16> to vector<32x64xf16>
40+
return %0 : vector<32x64xf16>
41+
}
42+
43+
}
44+
45+
// -----
46+
47+
gpu.module @test_kernel {
48+
49+
// CHECK-LABEL: test_blocking_transpose_small
50+
// CHECK-SAME: (%[[SRC:.*]]: vector<16x8xf16>)
51+
// CHECK: %[[PACK:.*]] = xetile.tile_pack %[[SRC]] { inner_blocks = [16, 8] } : vector<16x8xf16> -> vector<1x1x16x8xf16>
52+
// CHECK: %[[T:.*]] = vector.transpose %[[PACK]], [1, 0, 3, 2] : vector<1x1x16x8xf16> to vector<1x1x8x16xf16>
53+
// CHECK: %[[UNPACK:.*]] = xetile.tile_unpack %[[T]] { inner_blocks = [8, 16] } : vector<1x1x8x16xf16> -> vector<8x16xf16>
54+
// CHECK: return %[[UNPACK]] : vector<8x16xf16>
55+
func.func @test_blocking_transpose_small(%a: vector<16x8xf16>) -> vector<8x16xf16> {
56+
%0 = vector.transpose %a, [1, 0]: vector<16x8xf16> to vector<8x16xf16>
57+
return %0 : vector<8x16xf16>
58+
}
59+
60+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s
2+
3+
// CHECK-LABEL: test_transpose
4+
// Compare original args order with transposed
5+
// CHECK: %[[RES1:.*]] = builtin.unrealized_conversion_cast %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]], %[[ARG4:.*]], %[[ARG5:.*]], %[[ARG6:.*]], %[[ARG7:.*]], %[[ARG8:.*]] :
6+
// CHECK-DAG: %[[TARG1:.*]] = vector.transpose %[[ARG1]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
7+
// CHECK-DAG: %[[TARG2:.*]] = vector.transpose %[[ARG2]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
8+
// CHECK-DAG: %[[TARG3:.*]] = vector.transpose %[[ARG3]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
9+
// CHECK-DAG: %[[TARG4:.*]] = vector.transpose %[[ARG4]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
10+
// CHECK-DAG: %[[TARG5:.*]] = vector.transpose %[[ARG5]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
11+
// CHECK-DAG: %[[TARG6:.*]] = vector.transpose %[[ARG6]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
12+
// CHECK-DAG: %[[TARG7:.*]] = vector.transpose %[[ARG7]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
13+
// CHECK-DAG: %[[TARG8:.*]] = vector.transpose %[[ARG8]], [1, 0] : vector<1x16xf16> to vector<16x1xf16>
14+
// CHECK: %[[RES2:.*]] = builtin.unrealized_conversion_cast %[[TARG1]], %[[TARG5]], %[[TARG2]], %[[TARG6]], %[[TARG3]], %[[TARG7]], %[[TARG4]], %[[TARG8]]
15+
// CHECK: gpu.return %[[RES1]], %[[RES2]]
16+
gpu.module @test_kernel {
17+
gpu.func @test_transpose(%a: memref<2x64xf16>) -> (vector<2x4x1x16xf16>, vector<4x2x16x1xf16>) {
18+
%c0 = arith.constant 0 : index
19+
%0 = xetile.init_tile %a[%c0, %c0] : memref<2x64xf16> -> !xetile.tile<2x64xf16, #xetile.tile_attr<inner_blocks = [1, 16]>>
20+
%1 = xetile.load_tile %0 : !xetile.tile<2x64xf16, #xetile.tile_attr<inner_blocks = [1, 16]>> -> vector<2x4x1x16xf16>
21+
%2 = vector.transpose %1, [1, 0, 3, 2] : vector<2x4x1x16xf16> to vector<4x2x16x1xf16>
22+
gpu.return %1, %2 : vector<2x4x1x16xf16>, vector<4x2x16x1xf16>
23+
}
24+
}

0 commit comments

Comments
 (0)