1111#include " gc/Dialect/LLVMIR/XeVMDialect.h"
1212#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313#include " mlir/Conversion/LLVMCommon/Pattern.h"
14+ #include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1415#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1516#include " mlir/Pass/Pass.h"
1617#include " mlir/Support/LLVM.h"
18+ #include " llvm/Support/FormatVariadic.h"
19+
20+ #include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/Types.h"
22+
23+ #include " llvm/ADT/STLExtras.h"
24+ #include " llvm/ADT/TypeSwitch.h"
25+ #include " llvm/Support/raw_ostream.h"
1726
1827#define DEBUG_TYPE " xevm-to-llvm"
1928
@@ -26,6 +35,231 @@ using namespace mlir;
2635using namespace xevm ;
2736
2837namespace {
38+ struct LLVMFuncAttributeOptions {
39+ bool isConvergent = false ;
40+ bool isNoUnwind = false ;
41+ bool isWillReturn = false ;
42+ LLVM::MemoryEffectsAttr memEffectsAttr{};
43+ };
44+ // static constexpr LLVMFuncAttributeOptions convergentAttrs = {
45+ // true, false, false, {}};
46+ // static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
47+ // false, true, false, {}};
48+ static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
49+ false , true , true , {}};
50+ // static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs =
51+ // {
52+ // true, true, true, {}};
53+
54+ std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
55+ return TypeSwitch<Type, std::string>(ty)
56+ .Case ([isUnsigned](VectorType ty) -> std::string {
57+ return " Dv" + std::to_string (ty.getNumElements ()) + " _" +
58+ getTypeMangling (ty.getElementType (), isUnsigned);
59+ })
60+ .Case ([](Float16Type) -> std::string { return " Dh" ; })
61+ .Case ([](Float32Type) -> std::string { return " f" ; })
62+ .Case ([](Float64Type) -> std::string { return " d" ; })
63+ .Case ([isUnsigned](IntegerType ty) -> std::string {
64+ switch (ty.getWidth ()) {
65+ case 8 :
66+ return isUnsigned ? " h" : " c" ;
67+ case 16 :
68+ return isUnsigned ? " t" : " s" ;
69+ case 32 :
70+ return isUnsigned ? " j" : " i" ;
71+ case 64 :
72+ return isUnsigned ? " m" : " l" ;
73+ default :
74+ llvm_unreachable (" unhandled integer type" );
75+ }
76+ });
77+ }
78+
79+ template <typename OpType>
80+ static std::optional<ArrayAttr>
81+ getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
82+ const bool isLoad) {
83+ if ((op.getL1CacheControlAttr () ==
84+ xevm::L1StoreCacheControlAttr::get (
85+ rewriter.getContext (), xevm::L1StoreCacheControl::DEFAULT) &&
86+ op.getL3CacheControlAttr () ==
87+ xevm::L3StoreCacheControlAttr::get (
88+ rewriter.getContext (), xevm::L3StoreCacheControl::DEFAULT)) ||
89+
90+ (op.getL1CacheControlAttr () ==
91+ xevm::L1LoadCacheControlAttr::get (
92+ rewriter.getContext (), xevm::L1LoadCacheControl::DEFAULT) &&
93+ op.getL3CacheControlAttr () ==
94+ xevm::L3LoadCacheControlAttr::get (
95+ rewriter.getContext (), xevm::L3LoadCacheControl::DEFAULT))) {
96+ return {};
97+ }
98+ constexpr int32_t decorationCacheControlArity{4 };
99+ constexpr int32_t loadCacheControlKey{6442 };
100+ constexpr int32_t storeCacheControlKey{6443 };
101+ constexpr int32_t l1Level{0 };
102+ constexpr int32_t l3Level{1 };
103+ const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
104+ SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
105+ controlKey, l1Level, static_cast <int32_t >(op.getL1CacheControl ()), 0 };
106+ SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
107+ controlKey, l3Level, static_cast <int32_t >(op.getL3CacheControl ()), 0 };
108+ auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
109+ auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
110+
111+ SmallVector<Attribute, 2 > combinedAttrs = {arrayAttrL1, arrayAttrL3};
112+ return rewriter.getArrayAttr (combinedAttrs);
113+ }
114+
115+ static LLVM::CallOp createDeviceFunctionCall (
116+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
117+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
118+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
119+ LLVMFuncAttributeOptions funcAttributeOptions) {
120+ auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
121+ MLIRContext *ctx = rewriter.getContext ();
122+ Location loc = UnknownLoc::get (ctx);
123+
124+ LLVM::LLVMFuncOp funcOp =
125+ LLVM::lookupOrCreateFn (moduleOp, funcName, argTypes, retType);
126+ funcOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
127+ funcOp.setConvergent (funcAttributeOptions.isConvergent );
128+ funcOp.setNoUnwind (funcAttributeOptions.isNoUnwind );
129+ funcOp.setWillReturn (funcAttributeOptions.isWillReturn );
130+
131+ if (funcAttributeOptions.memEffectsAttr )
132+ funcOp.setMemoryEffectsAttr (funcAttributeOptions.memEffectsAttr );
133+
134+ for (auto [idx, attrName] : paramAttrs)
135+ funcOp.setArgAttr (idx, attrName, rewriter.getUnitAttr ());
136+
137+ // if (!passthroughAttrs.getFnAttributes().empty())
138+ // funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));
139+
140+ auto callOp = rewriter.create <LLVM::CallOp>(loc, funcOp, args);
141+ callOp->setAttrs (funcOp->getAttrs ());
142+
143+ return callOp;
144+ }
145+
146+ template <typename OpType>
147+ class LoadStorePrefetchToOCLPattern : public OpConversionPattern <OpType> {
148+ using OpConversionPattern<OpType>::OpConversionPattern;
149+ LogicalResult
150+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
151+ ConversionPatternRewriter &rewriter) const override {
152+ constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
153+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStore2dOp>;
154+ constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
155+ auto loc = op.getLoc ();
156+ VectorType vecType;
157+ if constexpr (isLoad) {
158+ vecType = op.getRes ().getType ();
159+ } else if constexpr (isStore) {
160+ vecType = op.getStoredVal ().getType ();
161+ }
162+
163+ auto i32Type = rewriter.getI32Type ();
164+ bool vnni = false ;
165+ bool transpose = false ;
166+ if constexpr (isLoad) {
167+ vnni = op.getVnniTransform ();
168+ transpose = op.getTranspose ();
169+ }
170+
171+ Value byteCoord =
172+ rewriter.create <LLVM::UndefOp>(loc, VectorType::get (2 , i32Type));
173+ Value zero = rewriter.create <LLVM::ConstantOp>(
174+ loc, i32Type, rewriter.getI32IntegerAttr (0 ));
175+ Value one = rewriter.create <LLVM::ConstantOp>(
176+ loc, i32Type, rewriter.getI32IntegerAttr (1 ));
177+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
178+ loc, VectorType::get (2 , i32Type), byteCoord, op.getX (), zero);
179+ byteCoord = rewriter.create <LLVM::InsertElementOp>(
180+ loc, VectorType::get (2 , i32Type), byteCoord, op.getY (), one);
181+ SmallVector<Value> args{op.getPtr (), op.getBaseWidth (), op.getBaseHeight (),
182+ op.getBasePitch (), byteCoord};
183+ SmallVector<Type> retTypes;
184+ Value spvLoadDstPtr;
185+ std::string funcName, bitWidthId;
186+ SmallVector<std::pair<unsigned , mlir::StringRef>, 4 > paramAttrs;
187+ if constexpr (isPrefetch) { // Prefetch
188+ funcName = " intel_sub_group_2d_block_prefetch" ;
189+ paramAttrs = {std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ())};
190+ } else {
191+ auto vecElemType = vecType.getElementType ();
192+ auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth ();
193+ Value numElems = rewriter.create <LLVM::ConstantOp>(
194+ loc, i32Type, vecType.getNumElements ());
195+ auto dstOrSrcPtr = rewriter.create <LLVM::AllocaOp>(
196+ loc, LLVM::LLVMPointerType::get (rewriter.getContext ()), vecElemType,
197+ numElems);
198+ args.push_back (dstOrSrcPtr);
199+ if constexpr (isLoad) { // Load
200+ funcName = " intel_sub_group_2d_block_read" ;
201+ bitWidthId = getTypeMangling (vecElemType, /* isUnsigned=*/ true );
202+ if (vnni)
203+ funcName += " _transform" ;
204+ else if (transpose)
205+ funcName += " _transpose" ;
206+ spvLoadDstPtr = dstOrSrcPtr;
207+ retTypes.push_back (vecType);
208+ paramAttrs = {
209+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
210+ std::make_pair (0 , LLVM::LLVMDialect::getReadonlyAttrName ()),
211+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
212+ std::make_pair (5 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
213+ };
214+ } else { // Store
215+ funcName = " intel_sub_group_2d_block_write" ;
216+ bitWidthId = (vecElemBitWidth == 32 )
217+ ? " j"
218+ : ((vecElemBitWidth == 16 ) ? " t" : " h" );
219+ rewriter.create <LLVM::StoreOp>(loc, op.getStoredVal (), dstOrSrcPtr);
220+ paramAttrs = {
221+ std::make_pair (0 , LLVM::LLVMDialect::getNonNullAttrName ()),
222+ std::make_pair (0 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
223+ std::make_pair (5 , LLVM::LLVMDialect::getNonNullAttrName ()),
224+ std::make_pair (5 , LLVM::LLVMDialect::getReadonlyAttrName ()),
225+ };
226+ }
227+ }
228+
229+ // !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of
230+ // the instruction to decorate%}
231+ funcName =
232+ llvm::formatv (" {0}_{1}b_{2}r{3}x{4}c" , funcName, op.getElemSizeInBits (),
233+ op.getTileHeight (), op.getTileWidth (), op.getVBlocks ())
234+ .str ();
235+ funcName = llvm::formatv (" _Z{0}{1}PU3AS1viiiDv2_i{2}{3}" , funcName.size (),
236+ funcName, isPrefetch ? " " : " P" , bitWidthId)
237+ .str ();
238+ SmallVector<Type> argTypes;
239+ for (auto arg : args) {
240+ argTypes.push_back (arg.getType ());
241+ }
242+ LLVM::CallOp call = createDeviceFunctionCall (
243+ rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
244+ argTypes, args, paramAttrs, noUnwindWillReturnAttrs);
245+ if (std::optional<ArrayAttr> optCacheControls =
246+ getCacheControlMetadata (rewriter, op, isLoad || isPrefetch)) {
247+ call->setAttr (xevm::XeVMDialect::getCacheControlsAttrName (),
248+ *optCacheControls);
249+ }
250+ if constexpr (isLoad)
251+ rewriter.replaceOp (
252+ op, rewriter.create <LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
253+ else
254+ rewriter.eraseOp (op);
255+ return success ();
256+ }
257+ };
258+
259+ // ===----------------------------------------------------------------------===//
260+ // Pass Definition
261+ // ===----------------------------------------------------------------------===//
262+
29263struct ConvertXeVMToLLVMPass
30264 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31265 using Base::Base;
@@ -37,19 +271,51 @@ struct ConvertXeVMToLLVMPass
37271 void runOnOperation () override {
38272 ConversionTarget target (getContext ());
39273 target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
40- RewritePatternSet pattern (&getContext ());
41- mlir::populateXeVMToLLVMConversionPatterns (pattern);
42- if (failed (
43- applyPartialConversion (getOperation (), target, std::move (pattern))))
274+ target.addIllegalDialect <xevm::XeVMDialect>();
275+ RewritePatternSet patterns (&getContext ());
276+ mlir::populateXeVMToLLVMConversionPatterns (patterns);
277+ if (failed (applyPartialConversion (getOperation (), target,
278+ std::move (patterns))))
44279 signalPassFailure ();
45280 }
46281};
47282} // namespace
48283
284+ // ===----------------------------------------------------------------------===//
285+ // Pattern Population
286+ // ===----------------------------------------------------------------------===//
287+
49288void mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
50- /* TODO*/
289+ patterns.add <LoadStorePrefetchToOCLPattern<xevm::BlockLoad2dOp>,
290+ LoadStorePrefetchToOCLPattern<xevm::BlockStore2dOp>,
291+ LoadStorePrefetchToOCLPattern<xevm::BlockPrefetch2dOp>>(
292+ patterns.getContext ());
51293}
52294
295+ // ===----------------------------------------------------------------------===//
296+ // ConvertToLLVMPatternInterface implementation
297+ // ===----------------------------------------------------------------------===//
298+
299+ namespace {
300+ // / Implement the interface to convert XeVM to LLVM.
301+ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
302+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
303+ void loadDependentDialects (MLIRContext *context) const final {
304+ context->loadDialect <LLVM::LLVMDialect>();
305+ }
306+
307+ // / Hook for derived dialect interface to provide conversion patterns
308+ // / and mark dialect legal for the conversion target.
309+ void populateConvertToLLVMConversionPatterns (
310+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
311+ RewritePatternSet &patterns) const final {
312+ populateXeVMToLLVMConversionPatterns (patterns);
313+ }
314+ };
315+ } // namespace
316+
53317void mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
54- /* TODO*/
318+ registry.addExtension (+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
319+ dialect->addInterfaces <XeVMToLLVMDialectInterface>();
320+ });
55321}
0 commit comments