-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeVM] Add lowering for llvm load store ops with XeVM cache control #156768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
669d09d
733d437
8ae4428
695e141
ab9840b
760b031
3ed1f6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types, | |
| return os.str(); | ||
| } | ||
|
|
||
| template <bool isLoad, typename OpType> | ||
| int32_t getL1CacheControl(OpType op) { | ||
| static int32_t getL1CacheControl(LoadCacheControl cc) { | ||
| int32_t control = 0; | ||
| if constexpr (isLoad) { | ||
| switch (*op.getCacheControl()) { | ||
| case LoadCacheControl::L1UC_L2UC_L3UC: | ||
| case LoadCacheControl::L1UC_L2UC_L3C: | ||
| case LoadCacheControl::L1UC_L2C_L3UC: | ||
| case LoadCacheControl::L1UC_L2C_L3C: | ||
| control = 1; | ||
| break; | ||
| case LoadCacheControl::L1C_L2UC_L3UC: | ||
| case LoadCacheControl::L1C_L2UC_L3C: | ||
| case LoadCacheControl::L1C_L2C_L3UC: | ||
| case LoadCacheControl::L1C_L2C_L3C: | ||
| control = 2; | ||
| break; | ||
| case LoadCacheControl::L1S_L2UC_L3UC: | ||
| case LoadCacheControl::L1S_L2UC_L3C: | ||
| case LoadCacheControl::L1S_L2C_L3UC: | ||
| case LoadCacheControl::L1S_L2C_L3C: | ||
| control = 3; | ||
| break; | ||
| case LoadCacheControl::INVALIDATE_READ: | ||
| control = 4; | ||
| break; | ||
| } | ||
| } else { | ||
| switch (*op.getCacheControl()) { | ||
| case StoreCacheControl::L1UC_L2UC_L3UC: | ||
| case StoreCacheControl::L1UC_L2UC_L3WB: | ||
| case StoreCacheControl::L1UC_L2WB_L3UC: | ||
| case StoreCacheControl::L1UC_L2WB_L3WB: | ||
| control = 1; | ||
| break; | ||
| case StoreCacheControl::L1WT_L2UC_L3UC: | ||
| case StoreCacheControl::L1WT_L2UC_L3WB: | ||
| case StoreCacheControl::L1WT_L2WB_L3UC: | ||
| case StoreCacheControl::L1WT_L2WB_L3WB: | ||
| control = 2; | ||
| break; | ||
| case StoreCacheControl::L1S_L2UC_L3UC: | ||
| case StoreCacheControl::L1S_L2UC_L3WB: | ||
| case StoreCacheControl::L1S_L2WB_L3UC: | ||
| case StoreCacheControl::L1S_L2WB_L3WB: | ||
| control = 3; | ||
| break; | ||
| case StoreCacheControl::L1WB_L2UC_L3UC: | ||
| case StoreCacheControl::L1WB_L2WB_L3UC: | ||
| case StoreCacheControl::L1WB_L2UC_L3WB: | ||
| control = 4; | ||
| break; | ||
| } | ||
| switch (cc) { | ||
| case LoadCacheControl::L1UC_L2UC_L3UC: | ||
| case LoadCacheControl::L1UC_L2UC_L3C: | ||
| case LoadCacheControl::L1UC_L2C_L3UC: | ||
| case LoadCacheControl::L1UC_L2C_L3C: | ||
| control = 1; | ||
| break; | ||
| case LoadCacheControl::L1C_L2UC_L3UC: | ||
| case LoadCacheControl::L1C_L2UC_L3C: | ||
| case LoadCacheControl::L1C_L2C_L3UC: | ||
| case LoadCacheControl::L1C_L2C_L3C: | ||
| control = 2; | ||
| break; | ||
| case LoadCacheControl::L1S_L2UC_L3UC: | ||
| case LoadCacheControl::L1S_L2UC_L3C: | ||
| case LoadCacheControl::L1S_L2C_L3UC: | ||
| case LoadCacheControl::L1S_L2C_L3C: | ||
| control = 3; | ||
| break; | ||
| case LoadCacheControl::INVALIDATE_READ: | ||
| control = 4; | ||
| break; | ||
| } | ||
| return control; | ||
| } | ||
|
|
||
| template <bool isLoad, typename OpType> | ||
| int32_t getL3CacheControl(OpType op) { | ||
| static int32_t getL1CacheControl(StoreCacheControl cc) { | ||
| int32_t control = 0; | ||
| if constexpr (isLoad) { | ||
| switch (*op.getCacheControl()) { | ||
| case LoadCacheControl::L1UC_L2UC_L3UC: | ||
| case LoadCacheControl::L1UC_L2C_L3UC: | ||
| case LoadCacheControl::L1C_L2UC_L3UC: | ||
| case LoadCacheControl::L1C_L2C_L3UC: | ||
| case LoadCacheControl::L1S_L2UC_L3UC: | ||
| case LoadCacheControl::L1S_L2C_L3UC: | ||
| control = 1; | ||
| break; | ||
| case LoadCacheControl::L1UC_L2UC_L3C: | ||
| case LoadCacheControl::L1UC_L2C_L3C: | ||
| case LoadCacheControl::L1C_L2UC_L3C: | ||
| case LoadCacheControl::L1C_L2C_L3C: | ||
| case LoadCacheControl::L1S_L2UC_L3C: | ||
| case LoadCacheControl::L1S_L2C_L3C: | ||
| control = 2; | ||
| break; | ||
| case LoadCacheControl::INVALIDATE_READ: | ||
| control = 4; | ||
| break; | ||
| } | ||
| } else { | ||
| switch (*op.getCacheControl()) { | ||
| case StoreCacheControl::L1UC_L2UC_L3UC: | ||
| case StoreCacheControl::L1UC_L2WB_L3UC: | ||
| case StoreCacheControl::L1WT_L2UC_L3UC: | ||
| case StoreCacheControl::L1WT_L2WB_L3UC: | ||
| case StoreCacheControl::L1S_L2UC_L3UC: | ||
| case StoreCacheControl::L1S_L2WB_L3UC: | ||
| case StoreCacheControl::L1WB_L2UC_L3UC: | ||
| case StoreCacheControl::L1WB_L2WB_L3UC: | ||
| control = 1; | ||
| break; | ||
| case StoreCacheControl::L1UC_L2UC_L3WB: | ||
| case StoreCacheControl::L1UC_L2WB_L3WB: | ||
| case StoreCacheControl::L1WT_L2UC_L3WB: | ||
| case StoreCacheControl::L1WT_L2WB_L3WB: | ||
| case StoreCacheControl::L1S_L2UC_L3WB: | ||
| case StoreCacheControl::L1S_L2WB_L3WB: | ||
| case StoreCacheControl::L1WB_L2UC_L3WB: | ||
| control = 2; | ||
| break; | ||
| } | ||
| switch (cc) { | ||
| case StoreCacheControl::L1UC_L2UC_L3UC: | ||
| case StoreCacheControl::L1UC_L2UC_L3WB: | ||
| case StoreCacheControl::L1UC_L2WB_L3UC: | ||
| case StoreCacheControl::L1UC_L2WB_L3WB: | ||
| control = 1; | ||
| break; | ||
| case StoreCacheControl::L1WT_L2UC_L3UC: | ||
| case StoreCacheControl::L1WT_L2UC_L3WB: | ||
| case StoreCacheControl::L1WT_L2WB_L3UC: | ||
| case StoreCacheControl::L1WT_L2WB_L3WB: | ||
| control = 2; | ||
| break; | ||
| case StoreCacheControl::L1S_L2UC_L3UC: | ||
| case StoreCacheControl::L1S_L2UC_L3WB: | ||
| case StoreCacheControl::L1S_L2WB_L3UC: | ||
| case StoreCacheControl::L1S_L2WB_L3WB: | ||
| control = 3; | ||
| break; | ||
| case StoreCacheControl::L1WB_L2UC_L3UC: | ||
| case StoreCacheControl::L1WB_L2WB_L3UC: | ||
| case StoreCacheControl::L1WB_L2UC_L3WB: | ||
| control = 4; | ||
| break; | ||
| } | ||
| return control; | ||
| } | ||
|
|
||
| template <bool isLoad, typename OpType> | ||
| static int32_t getL3CacheControl(LoadCacheControl cc) { | ||
| int32_t control = 0; | ||
| switch (cc) { | ||
| case LoadCacheControl::L1UC_L2UC_L3UC: | ||
| case LoadCacheControl::L1UC_L2C_L3UC: | ||
| case LoadCacheControl::L1C_L2UC_L3UC: | ||
| case LoadCacheControl::L1C_L2C_L3UC: | ||
| case LoadCacheControl::L1S_L2UC_L3UC: | ||
| case LoadCacheControl::L1S_L2C_L3UC: | ||
| control = 1; | ||
| break; | ||
| case LoadCacheControl::L1UC_L2UC_L3C: | ||
| case LoadCacheControl::L1UC_L2C_L3C: | ||
| case LoadCacheControl::L1C_L2UC_L3C: | ||
| case LoadCacheControl::L1C_L2C_L3C: | ||
| case LoadCacheControl::L1S_L2UC_L3C: | ||
| case LoadCacheControl::L1S_L2C_L3C: | ||
| control = 2; | ||
| break; | ||
| case LoadCacheControl::INVALIDATE_READ: | ||
| control = 4; | ||
| break; | ||
| } | ||
| return control; | ||
| } | ||
|
|
||
| static int32_t getL3CacheControl(StoreCacheControl cc) { | ||
| int32_t control = 0; | ||
| switch (cc) { | ||
| case StoreCacheControl::L1UC_L2UC_L3UC: | ||
| case StoreCacheControl::L1UC_L2WB_L3UC: | ||
| case StoreCacheControl::L1WT_L2UC_L3UC: | ||
| case StoreCacheControl::L1WT_L2WB_L3UC: | ||
| case StoreCacheControl::L1S_L2UC_L3UC: | ||
| case StoreCacheControl::L1S_L2WB_L3UC: | ||
| case StoreCacheControl::L1WB_L2UC_L3UC: | ||
| case StoreCacheControl::L1WB_L2WB_L3UC: | ||
| control = 1; | ||
| break; | ||
| case StoreCacheControl::L1UC_L2UC_L3WB: | ||
| case StoreCacheControl::L1UC_L2WB_L3WB: | ||
| case StoreCacheControl::L1WT_L2UC_L3WB: | ||
| case StoreCacheControl::L1WT_L2WB_L3WB: | ||
| case StoreCacheControl::L1S_L2UC_L3WB: | ||
| case StoreCacheControl::L1S_L2WB_L3WB: | ||
| case StoreCacheControl::L1WB_L2UC_L3WB: | ||
| control = 2; | ||
| break; | ||
| } | ||
| return control; | ||
| } | ||
|
|
||
| static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) { | ||
| return op.getCacheControl(); | ||
| } | ||
|
|
||
| static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) { | ||
| return op.getCacheControl(); | ||
| } | ||
|
|
||
| static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) { | ||
| return op.getCacheControl(); | ||
| } | ||
|
|
||
| static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) { | ||
| return op.getCacheControl(); | ||
| } | ||
|
|
||
| static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) { | ||
| if (op->hasAttr("cache_control")) { | ||
| auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"); | ||
| if (!attr) | ||
| return std::nullopt; | ||
| return std::optional<LoadCacheControl>(attr.getValue()); | ||
| } | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) { | ||
| if (op->hasAttr("cache_control")) { | ||
| auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control"); | ||
| if (!attr) | ||
| return std::nullopt; | ||
| return std::optional<StoreCacheControl>(attr.getValue()); | ||
| } | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| template <typename OpType> | ||
| int32_t getL1CacheControl(OpType op) { | ||
| return getL1CacheControl(*getCacheControl(op)); | ||
| } | ||
|
|
||
| template <typename OpType> | ||
| int32_t getL3CacheControl(OpType op) { | ||
| return getL3CacheControl(*getCacheControl(op)); | ||
| } | ||
|
|
||
| template <typename OpType> | ||
| static std::optional<ArrayAttr> | ||
| getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { | ||
| if (!op.getCacheControl()) | ||
| if (!getCacheControl(op)) | ||
| return {}; | ||
| constexpr int32_t decorationCacheControlArity{4}; | ||
| constexpr int32_t loadCacheControlKey{6442}; | ||
| constexpr int32_t storeCacheControlKey{6443}; | ||
| constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> || | ||
| std::is_same_v<OpType, BlockPrefetch2dOp> || | ||
| std::is_same_v<OpType, LLVM::LoadOp> || | ||
| std::is_same_v<OpType, PrefetchOp>; | ||
| const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; | ||
| SmallVector<int32_t, decorationCacheControlArity> decorationsL1{ | ||
| controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0}; | ||
| controlKey, 0, getL1CacheControl<OpType>(op), 0}; | ||
| SmallVector<int32_t, decorationCacheControlArity> decorationsL3{ | ||
| controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0}; | ||
| controlKey, 1, getL3CacheControl<OpType>(op), 0}; | ||
| auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); | ||
| auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); | ||
|
|
||
|
|
@@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { | |
| rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), | ||
| argTypes, args, {}, funcAttr, op.getOperation()); | ||
| if (std::optional<ArrayAttr> optCacheControls = | ||
| getCacheControlMetadata<true>(rewriter, op)) | ||
| getCacheControlMetadata(rewriter, op)) | ||
| call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
|
|
@@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { | |
| rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), | ||
| argTypes, args, paramAttrs, funcAttr, op.getOperation()); | ||
| if (std::optional<ArrayAttr> optCacheControls = | ||
| getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { | ||
| getCacheControlMetadata(rewriter, op)) { | ||
| call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); | ||
| } | ||
| if constexpr (isLoad) | ||
|
|
@@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { | |
| return success(); | ||
| } | ||
| }; | ||
| template <typename OpType> | ||
| class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { | ||
| using OpConversionPattern<OpType>::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| if (!op->hasAttr("cache_control")) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about defining this attr name similarly to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea, and I also thought about it but that would touch dialect files. |
||
| return failure(); | ||
| std::optional<ArrayAttr> optCacheControls = | ||
| getCacheControlMetadata(rewriter, op); | ||
| op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); | ||
| op->removeAttr("cache_control"); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Pass Definition | ||
|
|
@@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass | |
|
|
||
| void runOnOperation() override { | ||
| ConversionTarget target(getContext()); | ||
| target.addLegalDialect<LLVM::LLVMDialect>(); | ||
| target.addIllegalDialect<XeVMDialect>(); | ||
| RewritePatternSet patterns(&getContext()); | ||
| populateXeVMToLLVMConversionPatterns(patterns); | ||
| populateXeVMToLLVMConversionPatterns(target, patterns); | ||
| if (failed(applyPartialConversion(getOperation(), target, | ||
| std::move(patterns)))) | ||
| signalPassFailure(); | ||
|
|
@@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { | |
| void populateConvertToLLVMConversionPatterns( | ||
| ConversionTarget &target, LLVMTypeConverter &typeConverter, | ||
| RewritePatternSet &patterns) const final { | ||
| populateXeVMToLLVMConversionPatterns(patterns); | ||
| populateXeVMToLLVMConversionPatterns(target, patterns); | ||
| } | ||
| }; | ||
| } // namespace | ||
|
|
@@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { | |
| // Pattern Population | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { | ||
| void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, | ||
| RewritePatternSet &patterns) { | ||
| target.addDynamicallyLegalDialect<LLVM::LLVMDialect>( | ||
| [](Operation *op) { return !op->hasAttr("cache_control"); }); | ||
| target.addIllegalDialect<XeVMDialect>(); | ||
| patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>, | ||
| LoadStorePrefetchToOCLPattern<BlockStore2dOp>, | ||
| LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, | ||
| MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( | ||
| patterns.getContext()); | ||
| MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern, | ||
| LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, | ||
| LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext()); | ||
| } | ||
|
|
||
| void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we collapse these functions into a type-restricted templated version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried but hot sure how to do it. The case is a bit tricky because return type depends on input type, but not something that can be inferred by the compiler or easily expressed with
declspec.If there is a clever way, will do in a separate PR alone with attr name mentioned above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I meant it at least for cases that have a matching return type, a separate PR is fine.