34 changes: 24 additions & 10 deletions mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
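Note (not part of the patch): the lowering the doc comment above describes amounts, roughly, to the rewrite sketched below. This is an illustrative sketch only — the function name, value names, memref shape, and vector width are made up, and the real pattern also handles the fully-in-bounds / fully-out-of-bounds distinction — but it shows the intended shape of the transformation.

Before:
  func.func @load(%buf: memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>,
                  %i: index, %mask: vector<4xi1>, %pass: vector<4xf32>) -> vector<4xf32> {
    %v = vector.maskedload %buf[%i], %mask, %pass
        : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32>
          into vector<4xf32>
    return %v : vector<4xf32>
  }

After (out-of-bounds lanes of a fat raw buffer load read back as zero rather than faulting, so the unconditional load is safe and the mask is applied with a select):
  func.func @load(%buf: memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>,
                  %i: index, %mask: vector<4xi1>, %pass: vector<4xf32>) -> vector<4xf32> {
    %loaded = vector.load %buf[%i]
        : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %v = arith.select %mask, %loaded, %pass : vector<4xi1>, vector<4xf32>
    return %v : vector<4xf32>
  }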
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    //    that writes at offsets x < num_records wouldn't trap, which is
+    //    something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, fully masked stores to buffers still go down this
+    // conditional path at present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
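Note (again not part of the patch): the conditional path this comment refers to keeps producing roughly the rewrite sketched below for fully masked stores. Illustrative sketch only — names and shapes are made up, and it assumes the mask comes from a vector.broadcast of an i1, which is (to my reading) the mask shape matchFullMask recognizes.

Before:
  func.func @store(%mem: memref<64xf32>, %i: index, %cond: i1, %val: vector<4xf32>) {
    %mask = vector.broadcast %cond : i1 to vector<4xi1>
    vector.maskedstore %mem[%i], %mask, %val
        : memref<64xf32>, vector<4xi1>, vector<4xf32>
    return
  }

After (the store is guarded by the scalar condition instead of a lane mask):
  func.func @store(%mem: memref<64xf32>, %i: index, %cond: i1, %val: vector<4xf32>) {
    scf.if %cond {
      vector.store %val, %mem[%i] : memref<64xf32>, vector<4xf32>
    }
    return
  }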