6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -23,7 +23,7 @@ namespace amdgpu {

#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

} // namespace amdgpu
} // namespace mlir
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
];
}

def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
let summary = "Lower the operations from the vector transfer_read to vector load";
def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
let summary = "Lower the operations from the vector maskedload to vector load";
let description = [{
This pass creates a transfer read op lowering optimization. The lowering
Member: the description still refers to transfer reads

Member Author: Hmmm, good catch... Unfortunately the PR just got merged. I'll leave a note to myself to fix this next time.

Member: the test names also seem like they may need updating

Member Author: Yep, agreed!

will produce a conditional check at runtime. If within bounds, a vector
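To make the truncated description above concrete: the guard that the new pass materializes (see MaskedloadToLoad.cpp further down) can be summarized by the following scalar sketch in plain C++. It is an illustration only, not the MLIR the pass emits, and every name in it is made up for this example.

#include <cstdint>

// Fall back to the original vector.maskedload only when the access both runs
// past the end of the buffer and is not word aligned; otherwise a plain
// vector.load combined with arith.select is safe.
bool needsMaskedFallback(int64_t bufferSize, int64_t linearIndex,
                         int64_t vectorSize, int64_t elementBitWidth) {
  int64_t delta = bufferSize - linearIndex;   // elements left in the buffer
  int64_t elementsPerWord = (32 + elementBitWidth - 1) / elementBitWidth;
  bool isOutOfBounds = delta < vectorSize;    // the load would cross the end
  bool isNotWordAligned = delta % elementsPerWord != 0;
  return isOutOfBounds && isNotWordAligned;
}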
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
ResolveStridedMetadata.cpp
TransferReadToLoad.cpp
MaskedloadToLoad.cpp

ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
167 changes: 167 additions & 0 deletions mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -0,0 +1,167 @@
//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in buffer address space.
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
vector::MaskedLoadOp maskedOp) {
auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
if (!memRefType)
return rewriter.notifyMatchFailure(maskedOp, "not a memref source");

Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
return rewriter.notifyMatchFailure(maskedOp, "no address space");

if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");

return success();
}

static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
vector::MaskedLoadOp maskedOp) {
VectorType vectorType = maskedOp.getVectorType();
Value load = builder.create<vector::LoadOp>(
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
Value res = builder.create<arith::SelectOp>(
loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
return res;
}

static constexpr char kMaskedloadNeedsMask[] =
"amdgpu.buffer_maskedload_needs_mask";

namespace {

struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
return failure();

if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
return failure();
}

Location loc = maskedOp.getLoc();
Value src = maskedOp.getBase();

VectorType vectorType = maskedOp.getVectorType();
int64_t vectorSize = vectorType.getNumElements();
int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
SmallVector<OpFoldResult> indices = maskedOp.getIndices();

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
SmallVector<OpFoldResult> strides =
stridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
memref::LinearizedMemRefInfo linearizedInfo;
OpFoldResult linearizedIndices;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
elementBitWidth, offset, sizes,
strides, indices);
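// With identical source and destination bit widths, linearizedIndices is the
// flat element index of the access and linearizedInfo.linearizedSize is the
// total element count of the buffer; both feed the bounds check below.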

// delta = bufferSize - linearizedOffset
Value vectorSizeOffset =
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
Value linearIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
Value totalSize = getValueOrCreateConstantIndexOp(
rewriter, loc, linearizedInfo.linearizedSize);
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);

// 1) check if delta < vectorSize
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

// 2) check if (delta % elements_per_word != 0)
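// (e.g. elements_per_word is 2 for 16-bit element types and 4 for 8-bit ones)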
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::divideCeil(32, elementBitWidth));
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne,
rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
rewriter.create<arith::ConstantIndexOp>(loc, 0));

// We take the fallback of the maskedload default lowering only if it is both
// out-of-bounds and not word aligned. The fallback ensures correct results
// when loading at the boundary of the buffer, since a buffer load returns
// inconsistent zeros for the whole word when the boundary is crossed.
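// For example, a 4xf16 load that starts three elements before the end of the
// buffer gives delta = 3 < 4 (out of bounds) and 3 % 2 != 0 (not word
// aligned), so that access stays on the masked-load path.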
Value ifCondition =
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);

auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*maskedOp.getOperation());
read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};

auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
rewriter.create<scf::YieldOp>(loc, res);
};

auto ifOp =
rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);

rewriter.replaceOp(maskedOp, ifOp);

return success();
}
};

} // namespace

void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
}

struct AmdgpuMaskedloadToLoadPass final
: amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateAmdgpuMaskedloadToLoadPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
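For reference, a minimal sketch of how a downstream pipeline might schedule the new pass from C++. The helper function is hypothetical, and the createAmdgpuMaskedloadToLoadPass() constructor name is assumed to be the one generated from the Passes.td entry above rather than confirmed by this patch.

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: run only this lowering over a module.
mlir::LogicalResult lowerBufferMaskedLoads(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Constructor name assumed from the tablegen'd pass declaration above.
  pm.addPass(mlir::amdgpu::createAmdgpuMaskedloadToLoadPass());
  return pm.run(module);
}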