
Commit 27c5497

Adding dynamic size check to avoid subword buffer load
1 parent bb1f32d commit 27c5497

5 files changed: +137 -19 lines

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 9 additions & 4 deletions
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
 def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
   let summary = "Lower the operations from the vector transfer_read to vector load";
   let description = [{
-    This pass creates a transfer read op lowering. A vector trasfer read op
-    will be lowered to a combination of vector.load, arith.select and
-    vector.broadcast.
+    This pass creates a transfer read op lowering optimization. The lowering
+    will produce a conditional check at runtime. If within bounds, a vector
+    transfer read op will be lowered to a combination of vector.load, arith.select
+    and vector.broadcast. If not, it will fall back to the default lowering
+    of the transfer_read op.

     This pattern will make it possible for masked transfer_read to be lowered
     towards buffer load with bounds check, allowing a more optimized global
     load accessing pattern compared with existing implementation of
     llvm.intr.masked.load on vectors.
   }];
-  let dependentDialects = [];
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "memref::MemRefDialect"
+  ];
 }
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
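
To make the new description concrete, the lowering has roughly the following IR shape. This is a hand-written sketch for a static 8x8 case similar to the existing test; the function names, the pre-folded index constants, and the exact textual form are assumptions for illustration, not output copied from the pass.

// Input: a masked transfer_read on a fat_raw_buffer memref.
func.func @example(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                   %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}

// Approximate result: a runtime bounds check guards the vector.load path;
// the else branch keeps the original transfer_read, tagged with
// amdgpu.transformed so the pattern does not fire on it again.
func.func @example_lowered(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                           %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %c4 = arith.constant 4 : index     // number of vector elements
  %c8 = arith.constant 8 : index     // stride of the outer dimension
  %c64 = arith.constant 64 : index   // total number of elements (8 * 8)
  %row = arith.muli %idx, %c8 : index
  %lin = arith.addi %row, %idx : index      // linearized start index
  %upper = arith.addi %lin, %c4 : index     // one past the last element read
  %cond = arith.cmpi ule, %upper, %c64 : index
  %res = scf.if %cond -> (vector<4xf32>) {
    %splat = vector.splat %cf0 : vector<4xf32>
    %load = vector.load %mem[%idx, %idx]
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %sel = arith.select %mask, %load, %splat : vector<4xi1>, vector<4xf32>
    scf.yield %sel : vector<4xf32>
  } else {
    %orig = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
        {amdgpu.transformed, in_bounds = [true]}
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    scf.yield %orig : vector<4xf32>
  }
  return %res : vector<4xf32>
}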

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRMemRefDialect
+  MLIRSCFDialect
   MLIRVectorDialect
   MLIRControlFlowDialect
   MLIRFuncDialect

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 81 additions & 12 deletions
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {

   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp->hasAttr("amdgpu.transformed"))
+      return failure();

     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }

     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index by linearIndex += indices[i] * stride
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+      if (i == 0)
+        break;
+
+      // Update stride for the next dimension
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add vector size offset to linear index
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }

-    rewriter.replaceOp(readOp, res);
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);

     return success();
   }
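
For orientation, the index arithmetic built by the two loops above corresponds roughly to the following IR for a memref<8x8xf32> access at [%idx, %idx] with a vector<4xf32> result. This is a hand-written sketch: SSA names are invented, and memref.dim ops take the place of the shape constants when a dimension is dynamic.

func.func @bounds_check_sketch(%idx : index) -> i1 {
  // Linearize the access indices, innermost dimension first (stride 1).
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %mul0 = arith.muli %idx, %c1 : index      // indices[1] * stride (= 1)
  %add0 = arith.addi %c0, %mul0 : index
  %c8 = arith.constant 8 : index            // shape[1]
  %stride = arith.muli %c1, %c8 : index     // stride of dimension 0
  %mul1 = arith.muli %idx, %stride : index  // indices[0] * stride (= 8)
  %lin = arith.addi %add0, %mul1 : index    // linearized start index
  // Upper bound of the access: start index + number of vector elements.
  %c4 = arith.constant 4 : index
  %upper = arith.addi %lin, %c4 : index
  // Total number of elements in the memref: 1 * 8 * 8.
  %dim0 = arith.muli %c1, %c8 : index
  %total = arith.muli %dim0, %c8 : index
  // The access is in bounds iff it ends within the buffer.
  %inbounds = arith.cmpi ule, %upper, %total : index
  return %inbounds : i1
}

Note that the strides are derived from the memref shape alone, i.e. the check assumes the default contiguous row-major layout, and it compares the linearized end of the access against the total element count of the memref.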

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 45 additions & 3 deletions
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   return %res : vector<4xf32>
 }
 // CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]

 // -----

@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>

 // -----

@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
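
This test file is lit/FileCheck driven, and the pass is exposed under the pipeline name declared in Passes.td. Assuming the usual setup, a RUN line along the following lines exercises the new cases (the exact flags are an assumption, not part of this commit):

// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s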

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -1568,6 +1568,7 @@ cc_library(
         ":IR",
         ":MemRefDialect",
         ":Pass",
+        ":SCFDialect",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",
