
Commit 8949dc7

[mlir][amdgpu] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (#152277)
1 parent 7f1638e commit 8949dc7

File tree: 2 files changed, +143 -57 lines changed

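This commit extends the existing source-operand folding so that memref.subview, memref.expand_shape, and memref.collapse_shape producers are also folded into the LDS destination operand of amdgpu.gather_to_lds, falling back to the original operand and indices whenever a view cannot be folded. A minimal before/after sketch of the destination-side fold, adapted from the @test_expand_shape_dst_only case added below (the %idx_lds name is illustrative; the test captures it as %[[IDX_LDS]]):

// Before: the LDS destination is a view produced by memref.expand_shape.
%expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64]
    : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>

// After: the view is folded away and the destination indices are linearized
// onto the base allocation.
%idx_lds = affine.linearize_index [%offset_j, %c0] by (64, 64) : index
amdgpu.gather_to_lds %mem[%offset_i], %alloc[%idx_lds]
    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>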

mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp

Lines changed: 63 additions & 46 deletions
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+                                      Value view, mlir::OperandRange indices,
+                                      SmallVectorImpl<Value> &resolvedIndices,
+                                      Value &memrefBase, StringRef role) {
+  Operation *defOp = view.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+      .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+            rewriter, loc, subviewOp.getMixedOffsets(),
+            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+            resolvedIndices);
+        memrefBase = subviewOp.getSource();
+        return success();
+      })
+      .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                loc, rewriter, expandShapeOp, indices, resolvedIndices,
+                false))) {
+          return failure();
+        }
+        memrefBase = expandShapeOp.getViewSource();
+        return success();
+      })
+      .Case<memref::CollapseShapeOp>(
+          [&](memref::CollapseShapeOp collapseShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                    loc, rewriter, collapseShapeOp, indices,
+                    resolvedIndices))) {
+              return failure();
+            }
+            memrefBase = collapseShapeOp.getViewSource();
+            return success();
+          })
+      .Default([&](Operation *op) {
+        return rewriter.notifyMatchFailure(
+            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+                        "CollapseShapeOp")
+                    .str());
+      });
+}
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
+
+    auto foldSrcResult =
+        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+                         sourceIndices, memrefSource, "source");
+
+    if (failed(foldSrcResult)) {
+      memrefSource = op.getSrc();
+      sourceIndices = op.getSrcIndices();
+    }
+
+    auto foldDstResult =
+        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+                         destIndices, memrefDest, "destination");
 
-    if (failed(foldResult)) {
-      return failure();
+    if (failed(foldDstResult)) {
+      memrefDest = op.getDst();
+      destIndices = op.getDstIndices();
     }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
+                                               memrefDest, destIndices,
                                                op.getTransferType());
 
     return success();

mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir

Lines changed: 80 additions & 11 deletions
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
 
-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
     : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
@@ -80,15 +82,82 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+
+// CHECK: func @test_expand_shape_src_raw_buffer
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_expand_shape_src_raw_buffer(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG1]], %[[ARG2]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[IDXM]]], %[[LOCAL]][%[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>> into memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %alloc[%c0]
+    : vector<8xf16>, memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape_dst_only
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape_dst_only(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX_LDS:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]]], %[[LOCAL]][%[[IDX_LDS]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<8192xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
     : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_nop
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_nop(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[ARG1]]], %[[LOCAL]][%[[ARG2]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %mem[%offset_i], %alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
