1111#include " gc/Dialect/LLVMIR/XeVMDialect.h"
1212#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313#include " mlir/Conversion/LLVMCommon/Pattern.h"
14+ #include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1415#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1516#include " mlir/Pass/Pass.h"
1617#include " mlir/Support/LLVM.h"
18+ #include " llvm/Support/FormatVariadic.h"
19+
20+ #include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/Types.h"
22+
23+ #include " llvm/ADT/TypeSwitch.h"
1724
1825#define DEBUG_TYPE " xevm-to-llvm"
1926
@@ -26,6 +33,226 @@ using namespace mlir;
2633using namespace xevm ;
2734
2835namespace {
36+ struct LLVMFuncAttributeOptions {
37+ bool isConvergent = false ;
38+ bool isNoUnwind = false ;
39+ bool isWillReturn = false ;
40+ LLVM::MemoryEffectsAttr memEffectsAttr{};
41+ };
42+ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
43+ false , true , false , {}};
44+ static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
45+ false , true , true , {}};
46+
47+ std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
48+ return TypeSwitch<Type, std::string>(ty)
49+ .Case ([isUnsigned](VectorType ty) -> std::string {
50+ return " Dv" + std::to_string (ty.getNumElements ()) + " _" +
51+ getTypeMangling (ty.getElementType (), isUnsigned);
52+ })
53+ .Case ([](Float16Type) -> std::string { return " Dh" ; })
54+ .Case ([](Float32Type) -> std::string { return " f" ; })
55+ .Case ([](Float64Type) -> std::string { return " d" ; })
56+ .Case ([isUnsigned](IntegerType ty) -> std::string {
57+ switch (ty.getWidth ()) {
58+ case 8 :
59+ return isUnsigned ? " h" : " c" ;
60+ case 16 :
61+ return isUnsigned ? " t" : " s" ;
62+ case 32 :
63+ return isUnsigned ? " j" : " i" ;
64+ case 64 :
65+ return isUnsigned ? " m" : " l" ;
66+ default :
67+ llvm_unreachable (" unhandled integer type" );
68+ }
69+ });
70+ }
71+
72+ template <typename OpType>
73+ static std::optional<ArrayAttr>
74+ getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
75+ const bool isLoad) {
76+ if ((op.getL1CacheControlAttr () ==
77+ xevm::L1StoreCacheControlAttr::get (
78+ rewriter.getContext (), xevm::L1StoreCacheControl::DEFAULT) &&
79+ op.getL3CacheControlAttr () ==
80+ xevm::L3StoreCacheControlAttr::get (
81+ rewriter.getContext (), xevm::L3StoreCacheControl::DEFAULT)) ||
82+
83+ (op.getL1CacheControlAttr () ==
84+ xevm::L1LoadCacheControlAttr::get (
85+ rewriter.getContext (), xevm::L1LoadCacheControl::DEFAULT) &&
86+ op.getL3CacheControlAttr () ==
87+ xevm::L3LoadCacheControlAttr::get (
88+ rewriter.getContext (), xevm::L3LoadCacheControl::DEFAULT))) {
89+ return {};
90+ }
91+ constexpr int32_t decorationCacheControlArity{4 };
92+ constexpr int32_t loadCacheControlKey{6442 };
93+ constexpr int32_t storeCacheControlKey{6443 };
94+ constexpr int32_t l1Level{0 };
95+ constexpr int32_t l3Level{1 };
96+ const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
97+ SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
98+ controlKey, l1Level, static_cast <int32_t >(op.getL1CacheControl ()), 0 };
99+ SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
100+ controlKey, l3Level, static_cast <int32_t >(op.getL3CacheControl ()), 0 };
101+ auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
102+ auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
103+
104+ SmallVector<Attribute, 2 > combinedAttrs = {arrayAttrL1, arrayAttrL3};
105+ return rewriter.getArrayAttr (combinedAttrs);
106+ }
107+
108+ static LLVM::CallOp createDeviceFunctionCall (
109+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
110+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
111+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
112+ LLVMFuncAttributeOptions funcAttributeOptions) {
113+ auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
114+ MLIRContext *ctx = rewriter.getContext ();
115+ Location loc = UnknownLoc::get (ctx);
116+
117+ LLVM::LLVMFuncOp funcOp =
118+ LLVM::lookupOrCreateFn (moduleOp, funcName, argTypes, retType);
119+ funcOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
120+ funcOp.setConvergent (funcAttributeOptions.isConvergent );
121+ funcOp.setNoUnwind (funcAttributeOptions.isNoUnwind );
122+ funcOp.setWillReturn (funcAttributeOptions.isWillReturn );
123+
124+ if (funcAttributeOptions.memEffectsAttr )
125+ funcOp.setMemoryEffectsAttr (funcAttributeOptions.memEffectsAttr );
126+
127+ for (auto [idx, attrName] : paramAttrs)
128+ funcOp.setArgAttr (idx, attrName, rewriter.getUnitAttr ());
129+
130+ auto callOp = rewriter.create <LLVM::CallOp>(loc, funcOp, args);
131+ callOp->setAttrs (funcOp->getAttrs ());
132+
133+ return callOp;
134+ }
135+
136+ template <typename OpType>
137+ class LoadStorePrefetchToOCLPattern : public OpConversionPattern <OpType> {
138+ using OpConversionPattern<OpType>::OpConversionPattern;
139+ LogicalResult
140+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
141+ ConversionPatternRewriter &rewriter) const override {
142+ constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
143+ constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
144+
145+ auto loc = op.getLoc ();
146+ VectorType vecType;
147+ bool vnni = false ;
148+ bool transpose = false ;
149+ if constexpr (isLoad) {
150+ vecType = op.getRes ().getType ();
151+ vnni = op.getVnniTransform ();
152+ transpose = op.getTranspose ();
153+ } else if constexpr (!isPrefetch) {
154+ vecType = op.getStoredVal ().getType ();
155+ }
156+
157+ auto i32Type = rewriter.getI32Type ();
158+ Value byteCoord =
159+ rewriter.create <LLVM::UndefOp>(loc, VectorType::get (2 , i32Type));
160+ Value zero = rewriter.create <LLVM::ConstantOp>(
161+ loc, i32Type, rewriter.getI32IntegerAttr (0 ));
162+ Value one = rewriter.create <LLVM::ConstantOp>(
163+ loc, i32Type, rewriter.getI32IntegerAttr (1 ));
164+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
165+ loc, VectorType::get (2 , i32Type), byteCoord, op.getX (), zero);
166+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
167+ loc, VectorType::get (2 , i32Type), byteCoord, op.getY (), one);
168+ SmallVector<Value> args{op.getPtr (), op.getBaseWidth (), op.getBaseHeight (),
169+ op.getBasePitch (), byteCoord};
170+ SmallVector<Type> retTypes;
171+ Value spvLoadDstPtr;
172+ std::string funcName{" intel_sub_group_2d_block_" };
173+ std::string bitWidthId;
174+ LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
175+ SmallVector<std::pair<unsigned , mlir::StringRef>, 4 > paramAttrs;
176+ if constexpr (isPrefetch) { // Prefetch
177+ funcName += " prefetch" ;
178+ paramAttrs = {std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ())};
179+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
180+ /* other=*/ LLVM::ModRefInfo::NoModRef,
181+ /* argMem=*/ LLVM::ModRefInfo::Ref,
182+ /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
183+ auto funcAttrs = noUnwindAttrs;
184+ funcAttrs.memEffectsAttr = memAttr;
185+ } else {
186+ auto vecElemType = vecType.getElementType ();
187+ auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth ();
188+ Value numElems = rewriter.create <LLVM::ConstantOp>(
189+ loc, i32Type, vecType.getNumElements ());
190+ auto dstOrSrcPtr = rewriter.create <LLVM::AllocaOp>(
191+ loc, LLVM::LLVMPointerType::get (rewriter.getContext ()), vecElemType,
192+ numElems);
193+ args.push_back (dstOrSrcPtr);
194+ if constexpr (isLoad) { // Load
195+ funcName += " read" ;
196+ bitWidthId = getTypeMangling (vecElemType, /* isUnsigned=*/ true );
197+ if (vnni)
198+ funcName += " _transform" ;
199+ else if (transpose)
200+ funcName += " _transpose" ;
201+ spvLoadDstPtr = dstOrSrcPtr;
202+ retTypes.push_back (vecType);
203+ paramAttrs = {
204+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
205+ std::make_pair (0 , LLVM::LLVMDialect::getReadonlyAttrName ()),
206+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
207+ std::make_pair (5 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
208+ };
209+ } else { // Store
210+ funcName += " write" ;
211+ bitWidthId = (vecElemBitWidth == 32 )
212+ ? " j"
213+ : ((vecElemBitWidth == 16 ) ? " t" : " h" );
214+ rewriter.create <LLVM::StoreOp>(loc, op.getStoredVal (), dstOrSrcPtr);
215+ paramAttrs = {
216+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
217+ std::make_pair (0 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
218+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
219+ std::make_pair (5 , LLVM::LLVMDialect::getReadonlyAttrName ()),
220+ };
221+ }
222+ }
223+
224+ funcName =
225+ llvm::formatv (" {0}_{1}b_{2}r{3}x{4}c" , funcName, op.getElemSizeInBits (),
226+ op.getTileHeight (), op.getTileWidth (), op.getVBlocks ())
227+ .str ();
228+ funcName = llvm::formatv (" _Z{0}{1}PU3AS1viiiDv2_i{2}{3}" , funcName.size (),
229+ funcName, isPrefetch ? " " : " P" , bitWidthId)
230+ .str ();
231+ SmallVector<Type> argTypes;
232+ for (auto arg : args) {
233+ argTypes.push_back (arg.getType ());
234+ }
235+ LLVM::CallOp call = createDeviceFunctionCall (
236+ rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
237+ argTypes, args, paramAttrs, funcAttr);
238+ if (std::optional<ArrayAttr> optCacheControls =
239+ getCacheControlMetadata (rewriter, op, isLoad || isPrefetch)) {
240+ call->setAttr (xevm::XeVMDialect::getCacheControlsAttrName (),
241+ *optCacheControls);
242+ }
243+ if constexpr (isLoad)
244+ rewriter.replaceOp (
245+ op, rewriter.create <LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
246+ else
247+ rewriter.eraseOp (op);
248+ return success ();
249+ }
250+ };
251+
252+ // ===----------------------------------------------------------------------===//
253+ // Pass Definition
254+ // ===----------------------------------------------------------------------===//
255+
29256struct ConvertXeVMToLLVMPass
30257 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31258 using Base::Base;
@@ -37,19 +264,51 @@ struct ConvertXeVMToLLVMPass
37264 void runOnOperation () override {
38265 ConversionTarget target (getContext ());
39266 target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
40- RewritePatternSet pattern (&getContext ());
41- mlir::populateXeVMToLLVMConversionPatterns (pattern);
42- if (failed (
43- applyPartialConversion (getOperation (), target, std::move (pattern))))
267+ target.addIllegalDialect <xevm::XeVMDialect>();
268+ RewritePatternSet patterns (&getContext ());
269+ mlir::populateXeVMToLLVMConversionPatterns (patterns);
270+ if (failed (applyPartialConversion (getOperation (), target,
271+ std::move (patterns))))
44272 signalPassFailure ();
45273 }
46274};
47275} // namespace
48276
277+ // ===----------------------------------------------------------------------===//
278+ // Pattern Population
279+ // ===----------------------------------------------------------------------===//
280+
49281void mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
50- /* TODO*/
282+ patterns.add <LoadStorePrefetchToOCLPattern<xevm::BlockLoad2dOp>,
283+ LoadStorePrefetchToOCLPattern<xevm::BlockStore2dOp>,
284+ LoadStorePrefetchToOCLPattern<xevm::BlockPrefetch2dOp>>(
285+ patterns.getContext ());
51286}
52287
288+ // ===----------------------------------------------------------------------===//
289+ // ConvertToLLVMPatternInterface implementation
290+ // ===----------------------------------------------------------------------===//
291+
292+ namespace {
293+ // / Implement the interface to convert XeVM to LLVM.
294+ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
295+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
296+ void loadDependentDialects (MLIRContext *context) const final {
297+ context->loadDialect <LLVM::LLVMDialect>();
298+ }
299+
300+ // / Hook for derived dialect interface to provide conversion patterns
301+ // / and mark dialect legal for the conversion target.
302+ void populateConvertToLLVMConversionPatterns (
303+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
304+ RewritePatternSet &patterns) const final {
305+ populateXeVMToLLVMConversionPatterns (patterns);
306+ }
307+ };
308+ } // namespace
309+
53310void mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
54- /* TODO*/
311+ registry.addExtension (+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
312+ dialect->addInterfaces <XeVMToLLVMDialectInterface>();
313+ });
55314}
0 commit comments