Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <memory>

namespace mlir {
class ConversionTarget;
class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
Expand All @@ -19,7 +20,8 @@ class Pass;
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"

void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
RewritePatternSet &patterns);

void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
Expand Down
292 changes: 181 additions & 111 deletions mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
return os.str();
}

template <bool isLoad, typename OpType>
int32_t getL1CacheControl(OpType op) {
static int32_t getL1CacheControl(LoadCacheControl cc) {
int32_t control = 0;
if constexpr (isLoad) {
switch (*op.getCacheControl()) {
case LoadCacheControl::L1UC_L2UC_L3UC:
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3UC:
case LoadCacheControl::L1UC_L2C_L3C:
control = 1;
break;
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1C_L2C_L3C:
control = 2;
break;
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3UC:
case LoadCacheControl::L1S_L2C_L3C:
control = 3;
break;
case LoadCacheControl::INVALIDATE_READ:
control = 4;
break;
}
} else {
switch (*op.getCacheControl()) {
case StoreCacheControl::L1UC_L2UC_L3UC:
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3UC:
case StoreCacheControl::L1UC_L2WB_L3WB:
control = 1;
break;
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1WT_L2WB_L3WB:
control = 2;
break;
case StoreCacheControl::L1S_L2UC_L3UC:
case StoreCacheControl::L1S_L2UC_L3WB:
case StoreCacheControl::L1S_L2WB_L3UC:
case StoreCacheControl::L1S_L2WB_L3WB:
control = 3;
break;
case StoreCacheControl::L1WB_L2UC_L3UC:
case StoreCacheControl::L1WB_L2WB_L3UC:
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 4;
break;
}
switch (cc) {
case LoadCacheControl::L1UC_L2UC_L3UC:
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3UC:
case LoadCacheControl::L1UC_L2C_L3C:
control = 1;
break;
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1C_L2C_L3C:
control = 2;
break;
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3UC:
case LoadCacheControl::L1S_L2C_L3C:
control = 3;
break;
case LoadCacheControl::INVALIDATE_READ:
control = 4;
break;
}
return control;
}

template <bool isLoad, typename OpType>
int32_t getL3CacheControl(OpType op) {
static int32_t getL1CacheControl(StoreCacheControl cc) {
int32_t control = 0;
if constexpr (isLoad) {
switch (*op.getCacheControl()) {
case LoadCacheControl::L1UC_L2UC_L3UC:
case LoadCacheControl::L1UC_L2C_L3UC:
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2C_L3UC:
control = 1;
break;
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3C:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3C:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3C:
control = 2;
break;
case LoadCacheControl::INVALIDATE_READ:
control = 4;
break;
}
} else {
switch (*op.getCacheControl()) {
case StoreCacheControl::L1UC_L2UC_L3UC:
case StoreCacheControl::L1UC_L2WB_L3UC:
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1S_L2UC_L3UC:
case StoreCacheControl::L1S_L2WB_L3UC:
case StoreCacheControl::L1WB_L2UC_L3UC:
case StoreCacheControl::L1WB_L2WB_L3UC:
control = 1;
break;
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3WB:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3WB:
case StoreCacheControl::L1S_L2UC_L3WB:
case StoreCacheControl::L1S_L2WB_L3WB:
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
}
switch (cc) {
case StoreCacheControl::L1UC_L2UC_L3UC:
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3UC:
case StoreCacheControl::L1UC_L2WB_L3WB:
control = 1;
break;
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1WT_L2WB_L3WB:
control = 2;
break;
case StoreCacheControl::L1S_L2UC_L3UC:
case StoreCacheControl::L1S_L2UC_L3WB:
case StoreCacheControl::L1S_L2WB_L3UC:
case StoreCacheControl::L1S_L2WB_L3WB:
control = 3;
break;
case StoreCacheControl::L1WB_L2UC_L3UC:
case StoreCacheControl::L1WB_L2WB_L3UC:
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 4;
break;
}
return control;
}

template <bool isLoad, typename OpType>
static int32_t getL3CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
case LoadCacheControl::L1UC_L2UC_L3UC:
case LoadCacheControl::L1UC_L2C_L3UC:
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2C_L3UC:
control = 1;
break;
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3C:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3C:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3C:
control = 2;
break;
case LoadCacheControl::INVALIDATE_READ:
control = 4;
break;
}
return control;
}

static int32_t getL3CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
case StoreCacheControl::L1UC_L2UC_L3UC:
case StoreCacheControl::L1UC_L2WB_L3UC:
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1S_L2UC_L3UC:
case StoreCacheControl::L1S_L2WB_L3UC:
case StoreCacheControl::L1WB_L2UC_L3UC:
case StoreCacheControl::L1WB_L2WB_L3UC:
control = 1;
break;
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3WB:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3WB:
case StoreCacheControl::L1S_L2UC_L3WB:
case StoreCacheControl::L1S_L2WB_L3WB:
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
}
return control;
}

static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we collapse these functions into a type-restricted templated version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but hot sure how to do it. The case is a bit tricky because return type depends on input type, but not something that can be inferred by the compiler or easily expressed with declspec.
If there is a clever way, will do in a separate PR alone with attr name mentioned above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I meant it at least for cases that have a matching return type, a separate PR is fine.

return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
return op.getCacheControl();
}

static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
if (!attr)
return std::nullopt;
return std::optional<LoadCacheControl>(attr.getValue());
}
return std::nullopt;
}

static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
if (!attr)
return std::nullopt;
return std::optional<StoreCacheControl>(attr.getValue());
}
return std::nullopt;
}

template <typename OpType>
int32_t getL1CacheControl(OpType op) {
return getL1CacheControl(*getCacheControl(op));
}

template <typename OpType>
int32_t getL3CacheControl(OpType op) {
return getL3CacheControl(*getCacheControl(op));
}

template <typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
if (!op.getCacheControl())
if (!getCacheControl(op))
return {};
constexpr int32_t decorationCacheControlArity{4};
constexpr int32_t loadCacheControlKey{6442};
constexpr int32_t storeCacheControlKey{6443};
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
std::is_same_v<OpType, BlockPrefetch2dOp> ||
std::is_same_v<OpType, LLVM::LoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
controlKey, 0, getL1CacheControl<OpType>(op), 0};
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
controlKey, 1, getL3CacheControl<OpType>(op), 0};
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);

Expand Down Expand Up @@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, {}, funcAttr, op.getOperation());
if (std::optional<ArrayAttr> optCacheControls =
getCacheControlMetadata<true>(rewriter, op))
getCacheControlMetadata(rewriter, op))
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, paramAttrs, funcAttr, op.getOperation());
if (std::optional<ArrayAttr> optCacheControls =
getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
getCacheControlMetadata(rewriter, op)) {
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
}
if constexpr (isLoad)
Expand All @@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op->hasAttr("cache_control"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about defining this attr name similarly to XeVMDialect::getCacheControlsAttrName()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, and I also thought about it but that would touch dialect files.
Will do in a separate PR to keep this one small.

return failure();
std::optional<ArrayAttr> optCacheControls =
getCacheControlMetadata(rewriter, op);
op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
op->removeAttr("cache_control");
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass Definition
Expand All @@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass

void runOnOperation() override {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<XeVMDialect>();
RewritePatternSet patterns(&getContext());
populateXeVMToLLVMConversionPatterns(patterns);
populateXeVMToLLVMConversionPatterns(target, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
Expand All @@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateXeVMToLLVMConversionPatterns(patterns);
populateXeVMToLLVMConversionPatterns(target, patterns);
}
};
} // namespace
Expand All @@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
// Pattern Population
//===----------------------------------------------------------------------===//

void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
RewritePatternSet &patterns) {
target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
[](Operation *op) { return !op->hasAttr("cache_control"); });
target.addIllegalDialect<XeVMDialect>();
patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
patterns.getContext());
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
}

void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ class XeVMDialectLLVMIRTranslationInterface
return handleDecorationCacheControl(instructions.front(),
cacheControlsArray.getValue());
}
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();

return success();
}

Expand Down
Loading