Commit 3f2e3e6

[mlir][AMDGPU][NFC] Fix overlapping masked load refinements (#159805)
The two patterns for handling vector.maskedload on AMD GPUs had an overlap - both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource. This commit adds checks to resolve the overlap.
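For illustration, a load of the shape below could be claimed by either pattern before this change: its mask is a broadcast of a scalar i1, and its source memref sits in the fat raw buffer address space. (This is a hand-written sketch, not a test from the commit; the function name and types are invented.)

    func.func @overlapping_case(%src: memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                                %idx: index, %pred: i1,
                                %passthru: vector<4xf32>) -> vector<4xf32> {
      // Broadcast scalar mask: the "scalar mask becomes an if statement" pattern could match this...
      %mask = vector.broadcast %pred : i1 to vector<4xi1>
      // ...and the fat-buffer source means the "load + select" pattern could match it too.
      %v = vector.maskedload %src[%idx], %mask, %passthru
          : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
      return %v : vector<4xf32>
    }

With the added checks, the scalar-mask-to-if pattern now bails out on fat buffer sources, leaving such loads to the load + select lowering.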
1 parent c50802c · commit 3f2e3e6

File tree: 1 file changed (+24, -10 lines)

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 24 additions & 10 deletions
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;

 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();

   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();

   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();

   return success();
 }

@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");

-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }

     // Check if this is either a full inbounds load or an empty, oob load. If

@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad

   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }

     Value cond = maybeCond.value();

@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore

   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    //    that writes to x < num_records of offset wouldn't trap, which is
+    //    something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
