@@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
9898 return os.str ();
9999}
100100
101- template <bool isLoad, typename OpType>
102- int32_t getL1CacheControl (OpType op) {
101+ static int32_t getL1CacheControl (LoadCacheControl cc) {
103102 int32_t control = 0 ;
104- if constexpr (isLoad) {
105- switch (*op.getCacheControl ()) {
106- case LoadCacheControl::L1UC_L2UC_L3UC:
107- case LoadCacheControl::L1UC_L2UC_L3C:
108- case LoadCacheControl::L1UC_L2C_L3UC:
109- case LoadCacheControl::L1UC_L2C_L3C:
110- control = 1 ;
111- break ;
112- case LoadCacheControl::L1C_L2UC_L3UC:
113- case LoadCacheControl::L1C_L2UC_L3C:
114- case LoadCacheControl::L1C_L2C_L3UC:
115- case LoadCacheControl::L1C_L2C_L3C:
116- control = 2 ;
117- break ;
118- case LoadCacheControl::L1S_L2UC_L3UC:
119- case LoadCacheControl::L1S_L2UC_L3C:
120- case LoadCacheControl::L1S_L2C_L3UC:
121- case LoadCacheControl::L1S_L2C_L3C:
122- control = 3 ;
123- break ;
124- case LoadCacheControl::INVALIDATE_READ:
125- control = 4 ;
126- break ;
127- }
128- } else {
129- switch (*op.getCacheControl ()) {
130- case StoreCacheControl::L1UC_L2UC_L3UC:
131- case StoreCacheControl::L1UC_L2UC_L3WB:
132- case StoreCacheControl::L1UC_L2WB_L3UC:
133- case StoreCacheControl::L1UC_L2WB_L3WB:
134- control = 1 ;
135- break ;
136- case StoreCacheControl::L1WT_L2UC_L3UC:
137- case StoreCacheControl::L1WT_L2UC_L3WB:
138- case StoreCacheControl::L1WT_L2WB_L3UC:
139- case StoreCacheControl::L1WT_L2WB_L3WB:
140- control = 2 ;
141- break ;
142- case StoreCacheControl::L1S_L2UC_L3UC:
143- case StoreCacheControl::L1S_L2UC_L3WB:
144- case StoreCacheControl::L1S_L2WB_L3UC:
145- case StoreCacheControl::L1S_L2WB_L3WB:
146- control = 3 ;
147- break ;
148- case StoreCacheControl::L1WB_L2UC_L3UC:
149- case StoreCacheControl::L1WB_L2WB_L3UC:
150- case StoreCacheControl::L1WB_L2UC_L3WB:
151- control = 4 ;
152- break ;
153- }
103+ switch (cc) {
104+ case LoadCacheControl::L1UC_L2UC_L3UC:
105+ case LoadCacheControl::L1UC_L2UC_L3C:
106+ case LoadCacheControl::L1UC_L2C_L3UC:
107+ case LoadCacheControl::L1UC_L2C_L3C:
108+ control = 1 ;
109+ break ;
110+ case LoadCacheControl::L1C_L2UC_L3UC:
111+ case LoadCacheControl::L1C_L2UC_L3C:
112+ case LoadCacheControl::L1C_L2C_L3UC:
113+ case LoadCacheControl::L1C_L2C_L3C:
114+ control = 2 ;
115+ break ;
116+ case LoadCacheControl::L1S_L2UC_L3UC:
117+ case LoadCacheControl::L1S_L2UC_L3C:
118+ case LoadCacheControl::L1S_L2C_L3UC:
119+ case LoadCacheControl::L1S_L2C_L3C:
120+ control = 3 ;
121+ break ;
122+ case LoadCacheControl::INVALIDATE_READ:
123+ control = 4 ;
124+ break ;
154125 }
155126 return control;
156127}
157128
158- template <bool isLoad, typename OpType>
159- int32_t getL3CacheControl (OpType op) {
129+ static int32_t getL1CacheControl (StoreCacheControl cc) {
160130 int32_t control = 0 ;
161- if constexpr (isLoad) {
162- switch (*op.getCacheControl ()) {
163- case LoadCacheControl::L1UC_L2UC_L3UC:
164- case LoadCacheControl::L1UC_L2C_L3UC:
165- case LoadCacheControl::L1C_L2UC_L3UC:
166- case LoadCacheControl::L1C_L2C_L3UC:
167- case LoadCacheControl::L1S_L2UC_L3UC:
168- case LoadCacheControl::L1S_L2C_L3UC:
169- control = 1 ;
170- break ;
171- case LoadCacheControl::L1UC_L2UC_L3C:
172- case LoadCacheControl::L1UC_L2C_L3C:
173- case LoadCacheControl::L1C_L2UC_L3C:
174- case LoadCacheControl::L1C_L2C_L3C:
175- case LoadCacheControl::L1S_L2UC_L3C:
176- case LoadCacheControl::L1S_L2C_L3C:
177- control = 2 ;
178- break ;
179- case LoadCacheControl::INVALIDATE_READ:
180- control = 4 ;
181- break ;
182- }
183- } else {
184- switch (*op.getCacheControl ()) {
185- case StoreCacheControl::L1UC_L2UC_L3UC:
186- case StoreCacheControl::L1UC_L2WB_L3UC:
187- case StoreCacheControl::L1WT_L2UC_L3UC:
188- case StoreCacheControl::L1WT_L2WB_L3UC:
189- case StoreCacheControl::L1S_L2UC_L3UC:
190- case StoreCacheControl::L1S_L2WB_L3UC:
191- case StoreCacheControl::L1WB_L2UC_L3UC:
192- case StoreCacheControl::L1WB_L2WB_L3UC:
193- control = 1 ;
194- break ;
195- case StoreCacheControl::L1UC_L2UC_L3WB:
196- case StoreCacheControl::L1UC_L2WB_L3WB:
197- case StoreCacheControl::L1WT_L2UC_L3WB:
198- case StoreCacheControl::L1WT_L2WB_L3WB:
199- case StoreCacheControl::L1S_L2UC_L3WB:
200- case StoreCacheControl::L1S_L2WB_L3WB:
201- case StoreCacheControl::L1WB_L2UC_L3WB:
202- control = 2 ;
203- break ;
204- }
131+ switch (cc) {
132+ case StoreCacheControl::L1UC_L2UC_L3UC:
133+ case StoreCacheControl::L1UC_L2UC_L3WB:
134+ case StoreCacheControl::L1UC_L2WB_L3UC:
135+ case StoreCacheControl::L1UC_L2WB_L3WB:
136+ control = 1 ;
137+ break ;
138+ case StoreCacheControl::L1WT_L2UC_L3UC:
139+ case StoreCacheControl::L1WT_L2UC_L3WB:
140+ case StoreCacheControl::L1WT_L2WB_L3UC:
141+ case StoreCacheControl::L1WT_L2WB_L3WB:
142+ control = 2 ;
143+ break ;
144+ case StoreCacheControl::L1S_L2UC_L3UC:
145+ case StoreCacheControl::L1S_L2UC_L3WB:
146+ case StoreCacheControl::L1S_L2WB_L3UC:
147+ case StoreCacheControl::L1S_L2WB_L3WB:
148+ control = 3 ;
149+ break ;
150+ case StoreCacheControl::L1WB_L2UC_L3UC:
151+ case StoreCacheControl::L1WB_L2WB_L3UC:
152+ case StoreCacheControl::L1WB_L2UC_L3WB:
153+ control = 4 ;
154+ break ;
205155 }
206156 return control;
207157}
208158
209- template <bool isLoad, typename OpType>
159+ static int32_t getL3CacheControl (LoadCacheControl cc) {
160+ int32_t control = 0 ;
161+ switch (cc) {
162+ case LoadCacheControl::L1UC_L2UC_L3UC:
163+ case LoadCacheControl::L1UC_L2C_L3UC:
164+ case LoadCacheControl::L1C_L2UC_L3UC:
165+ case LoadCacheControl::L1C_L2C_L3UC:
166+ case LoadCacheControl::L1S_L2UC_L3UC:
167+ case LoadCacheControl::L1S_L2C_L3UC:
168+ control = 1 ;
169+ break ;
170+ case LoadCacheControl::L1UC_L2UC_L3C:
171+ case LoadCacheControl::L1UC_L2C_L3C:
172+ case LoadCacheControl::L1C_L2UC_L3C:
173+ case LoadCacheControl::L1C_L2C_L3C:
174+ case LoadCacheControl::L1S_L2UC_L3C:
175+ case LoadCacheControl::L1S_L2C_L3C:
176+ control = 2 ;
177+ break ;
178+ case LoadCacheControl::INVALIDATE_READ:
179+ control = 4 ;
180+ break ;
181+ }
182+ return control;
183+ }
184+
185+ static int32_t getL3CacheControl (StoreCacheControl cc) {
186+ int32_t control = 0 ;
187+ switch (cc) {
188+ case StoreCacheControl::L1UC_L2UC_L3UC:
189+ case StoreCacheControl::L1UC_L2WB_L3UC:
190+ case StoreCacheControl::L1WT_L2UC_L3UC:
191+ case StoreCacheControl::L1WT_L2WB_L3UC:
192+ case StoreCacheControl::L1S_L2UC_L3UC:
193+ case StoreCacheControl::L1S_L2WB_L3UC:
194+ case StoreCacheControl::L1WB_L2UC_L3UC:
195+ case StoreCacheControl::L1WB_L2WB_L3UC:
196+ control = 1 ;
197+ break ;
198+ case StoreCacheControl::L1UC_L2UC_L3WB:
199+ case StoreCacheControl::L1UC_L2WB_L3WB:
200+ case StoreCacheControl::L1WT_L2UC_L3WB:
201+ case StoreCacheControl::L1WT_L2WB_L3WB:
202+ case StoreCacheControl::L1S_L2UC_L3WB:
203+ case StoreCacheControl::L1S_L2WB_L3WB:
204+ case StoreCacheControl::L1WB_L2UC_L3WB:
205+ control = 2 ;
206+ break ;
207+ }
208+ return control;
209+ }
210+
211+ static std::optional<LoadCacheControl> getCacheControl (PrefetchOp op) {
212+ return op.getCacheControl ();
213+ }
214+
215+ static std::optional<LoadCacheControl> getCacheControl (BlockLoad2dOp op) {
216+ return op.getCacheControl ();
217+ }
218+
219+ static std::optional<LoadCacheControl> getCacheControl (BlockPrefetch2dOp op) {
220+ return op.getCacheControl ();
221+ }
222+
223+ static std::optional<StoreCacheControl> getCacheControl (BlockStore2dOp op) {
224+ return op.getCacheControl ();
225+ }
226+
227+ static std::optional<LoadCacheControl> getCacheControl (LLVM::LoadOp op) {
228+ if (op->hasAttr (" cache_control" )) {
229+ auto attr = op->getAttrOfType <xevm::LoadCacheControlAttr>(" cache_control" );
230+ if (!attr)
231+ return std::nullopt ;
232+ return std::optional<LoadCacheControl>(attr.getValue ());
233+ }
234+ return std::nullopt ;
235+ }
236+
237+ static std::optional<StoreCacheControl> getCacheControl (LLVM::StoreOp op) {
238+ if (op->hasAttr (" cache_control" )) {
239+ auto attr = op->getAttrOfType <xevm::StoreCacheControlAttr>(" cache_control" );
240+ if (!attr)
241+ return std::nullopt ;
242+ return std::optional<StoreCacheControl>(attr.getValue ());
243+ }
244+ return std::nullopt ;
245+ }
246+
247+ template <typename OpType>
248+ int32_t getL1CacheControl (OpType op) {
249+ return getL1CacheControl (*getCacheControl (op));
250+ }
251+
252+ template <typename OpType>
253+ int32_t getL3CacheControl (OpType op) {
254+ return getL3CacheControl (*getCacheControl (op));
255+ }
256+
257+ template <typename OpType>
210258static std::optional<ArrayAttr>
211259getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op) {
212- if (!op. getCacheControl ())
260+ if (!getCacheControl (op ))
213261 return {};
214262 constexpr int32_t decorationCacheControlArity{4 };
215263 constexpr int32_t loadCacheControlKey{6442 };
216264 constexpr int32_t storeCacheControlKey{6443 };
265+ constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
266+ std::is_same_v<OpType, BlockPrefetch2dOp> ||
267+ std::is_same_v<OpType, LLVM::LoadOp> ||
268+ std::is_same_v<OpType, PrefetchOp>;
217269 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
218270 SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
219- controlKey, 0 , getL1CacheControl<isLoad, OpType>(op), 0 };
271+ controlKey, 0 , getL1CacheControl<OpType>(op), 0 };
220272 SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
221- controlKey, 1 , getL3CacheControl<isLoad, OpType>(op), 0 };
273+ controlKey, 1 , getL3CacheControl<OpType>(op), 0 };
222274 auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
223275 auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
224276
@@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
398450 rewriter, fnName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
399451 argTypes, args, {}, funcAttr, op.getOperation ());
400452 if (std::optional<ArrayAttr> optCacheControls =
401- getCacheControlMetadata< true > (rewriter, op))
453+ getCacheControlMetadata (rewriter, op))
402454 call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
403455 rewriter.eraseOp (op);
404456 return success ();
@@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
557609 rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
558610 argTypes, args, paramAttrs, funcAttr, op.getOperation ());
559611 if (std::optional<ArrayAttr> optCacheControls =
560- getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
612+ getCacheControlMetadata (rewriter, op)) {
561613 call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
562614 }
563615 if constexpr (isLoad)
@@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
568620 return success ();
569621 }
570622};
623+ template <typename OpType>
624+ class LLVMLoadStoreToOCLPattern : public OpConversionPattern <OpType> {
625+ using OpConversionPattern<OpType>::OpConversionPattern;
626+ LogicalResult
627+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
628+ ConversionPatternRewriter &rewriter) const override {
629+ if (!op->hasAttr (" cache_control" ))
630+ return failure ();
631+ std::optional<ArrayAttr> optCacheControls =
632+ getCacheControlMetadata (rewriter, op);
633+ op->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
634+ op->removeAttr (" cache_control" );
635+ return success ();
636+ }
637+ };
571638
572639// ===----------------------------------------------------------------------===//
573640// Pass Definition
@@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass
583650
584651 void runOnOperation () override {
585652 ConversionTarget target (getContext ());
586- target.addLegalDialect <LLVM::LLVMDialect>();
587- target.addIllegalDialect <XeVMDialect>();
588653 RewritePatternSet patterns (&getContext ());
589- populateXeVMToLLVMConversionPatterns (patterns);
654+ populateXeVMToLLVMConversionPatterns (target, patterns);
590655 if (failed (applyPartialConversion (getOperation (), target,
591656 std::move (patterns))))
592657 signalPassFailure ();
@@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
611676 void populateConvertToLLVMConversionPatterns (
612677 ConversionTarget &target, LLVMTypeConverter &typeConverter,
613678 RewritePatternSet &patterns) const final {
614- populateXeVMToLLVMConversionPatterns (patterns);
679+ populateXeVMToLLVMConversionPatterns (target, patterns);
615680 }
616681};
617682} // namespace
@@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
620685// Pattern Population
621686// ===----------------------------------------------------------------------===//
622687
623- void ::mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
688+ void ::mlir::populateXeVMToLLVMConversionPatterns (ConversionTarget &target,
689+ RewritePatternSet &patterns) {
690+ target.addDynamicallyLegalDialect <LLVM::LLVMDialect>(
691+ [](Operation *op) { return !op->hasAttr (" cache_control" ); });
692+ target.addIllegalDialect <XeVMDialect>();
624693 patterns.add <LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
625694 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
626695 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
627- MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
628- patterns.getContext ());
696+ MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
697+ LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
698+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext ());
629699}
630700
631701void ::mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
0 commit comments