Skip to content

Commit eac8c2b

Browse files
committed
Relaxing condition to do bounds check
1 parent 27c5497 commit eac8c2b

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
103103
return success();
104104
}
105105

106+
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
107+
vector::TransferReadOp readOp,
108+
bool requiresBroadcasting,
109+
VectorType unbroadcastedVectorType) {
110+
Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
111+
readOp.getPadding());
112+
Value load = builder.create<vector::LoadOp>(
113+
loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
114+
Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
115+
readOp.getMask(), load, fill);
116+
// Insert a broadcasting op if required.
117+
if (requiresBroadcasting) {
118+
res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
119+
}
120+
return res;
121+
}
122+
106123
namespace {
107124

108125
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
150167
stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
151168
}
152169

153-
// Add vector size offset to linear index
154-
VectorType vectorType = readOp.getVectorType();
155-
int64_t vectorSize = vectorType.getNumElements();
156-
Value vectorSizeOffset =
157-
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
158-
Value upperBoundIndex =
159-
rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
160-
161170
Value totalSize = one;
162171
for (size_t i = 0; i < shape.size(); ++i) {
163172
Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
169178
totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
170179
}
171180

172-
Value isInBounds = rewriter.create<arith::CmpIOp>(
173-
loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
181+
// delta = bufferSize - linearizedOffset
182+
// 1) check if delta < vectorSize
183+
VectorType vectorType = readOp.getVectorType();
184+
int64_t vectorSize = vectorType.getNumElements();
185+
Value vectorSizeOffset =
186+
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
187+
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
188+
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
189+
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
190+
191+
// 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
192+
int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
193+
Value deltaBytes = rewriter.create<arith::MulIOp>(
194+
loc, delta,
195+
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
196+
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
197+
loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
198+
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
199+
loc, arith::CmpIPredicate::ne,
200+
rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
201+
rewriter.create<arith::ConstantIndexOp>(loc, 0));
202+
203+
// We take the fallback of transfer_read default lowering only it is both
204+
// out-of-bounds and not word aligned.
205+
Value ifCondition =
206+
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
174207

175208
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
176-
Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
177-
readOp.getPadding());
178-
Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
179-
readOp.getSource(),
180-
readOp.getIndices());
181-
Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
182-
readOp.getMask(), load, fill);
183-
184-
// Insert a broadcasting op if required.
185-
if (requiresBroadcasting) {
186-
res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
187-
res);
188-
}
189-
rewriter.create<scf::YieldOp>(loc, res);
190-
};
191-
192-
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
193209
Operation *read = builder.clone(*readOp.getOperation());
194210
read->setAttr("amdgpu.transformed", builder.getUnitAttr());
195211
Value readResult = read->getResult(0);
196212
builder.create<scf::YieldOp>(loc, readResult);
197213
};
198214

215+
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
216+
Value res = createVectorLoadForMaskedLoad(
217+
builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
218+
rewriter.create<scf::YieldOp>(loc, res);
219+
};
220+
199221
auto ifOp =
200-
rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
222+
rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
201223

202224
rewriter.replaceOp(readOp, ifOp);
203225

0 commit comments

Comments
 (0)