77// ===----------------------------------------------------------------------===//
88
99#include " Attributes.h"
10+ #include " Utils/LLVMIntr.h"
1011#include " Utils/Mangling.h"
1112#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1213#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -52,71 +53,6 @@ using namespace mlir::triton::gpu;
5253// Helper Functions
5354// ===----------------------------------------------------------------------===//
5455
55- static intel::AttributeList createFunctionAttributes (
56- ArrayRef<std::pair<llvm::Attribute::AttrKind, std::optional<uint64_t >>>
57- attributes,
58- MLIRContext *ctx) {
59- intel::AttrBuilder funcAttrBuilder (*ctx);
60- for (auto [kind, optValue] : attributes) {
61- if (optValue)
62- funcAttrBuilder.addPassthroughAttribute (kind, *optValue);
63- else
64- funcAttrBuilder.addPassthroughAttribute (kind);
65- }
66-
67- intel::AttributeList attrs;
68- attrs.addFnAttributes (funcAttrBuilder);
69- return attrs;
70- }
71-
72- struct LLVMFuncAttributeOptions {
73- bool isConvergent = false ;
74- bool isNoUnwind = false ;
75- bool isWillReturn = false ;
76- LLVM::MemoryEffectsAttr memEffectsAttr{};
77- };
78-
79- static constexpr LLVMFuncAttributeOptions convergentAttrs = {
80- true , false , false , {}};
81- static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
82- false , true , false , {}};
83- static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
84- false , true , true , {}};
85- static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
86- true , true , true , {}};
87-
88- static LLVM::CallOp createDeviceFunctionCall (
89- ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
90- ArrayRef<Type> argTypes, ArrayRef<Value> args,
91- mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
92- const LLVMFuncAttributeOptions &funcAttributeOptions,
93- const intel::AttributeList &passthroughAttrs = {}) {
94- auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
95- MLIRContext *ctx = rewriter.getContext ();
96- Location loc = UnknownLoc::get (ctx);
97-
98- LLVM::LLVMFuncOp funcOp =
99- LLVM::lookupOrCreateFn (moduleOp, funcName, argTypes, retType);
100- funcOp.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
101- funcOp.setConvergent (funcAttributeOptions.isConvergent );
102- funcOp.setNoUnwind (funcAttributeOptions.isNoUnwind );
103- funcOp.setWillReturn (funcAttributeOptions.isWillReturn );
104-
105- if (funcAttributeOptions.memEffectsAttr )
106- funcOp.setMemoryEffectsAttr (funcAttributeOptions.memEffectsAttr );
107-
108- for (auto [idx, attrName] : paramAttrs)
109- funcOp.setArgAttr (idx, attrName, rewriter.getUnitAttr ());
110-
111- if (!passthroughAttrs.getFnAttributes ().empty ())
112- funcOp->setAttrs (passthroughAttrs.getFnAttributes ().getDictionary (ctx));
113-
114- auto callOp = rewriter.create <LLVM::CallOp>(loc, funcOp, args);
115- callOp->setAttrs (funcOp->getAttrs ());
116-
117- return callOp;
118- }
119-
12056[[maybe_unused]] static std::string getGenISATypeMangling (Type ty) {
12157 if (auto vecTy = dyn_cast<VectorType>(ty))
12258 return " v" + std::to_string (vecTy.getNumElements ()) +
@@ -230,8 +166,9 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
230166 b.i1_val (op.getVnniTransform ()),
231167 b.i32_val (static_cast <int >(op.getCacheControl ()))};
232168
233- LLVM::CallOp call = createDeviceFunctionCall (
234- rewriter, funcName, resType, argTypes, args, {}, noUnwindWillReturnAttrs);
169+ LLVM::CallOp call =
170+ intel::createDeviceFunctionCall (rewriter, funcName, resType, argTypes,
171+ args, {}, intel::noUnwindWillReturnAttrs);
235172 return call.getResult ();
236173}
237174
@@ -330,9 +267,9 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
330267 b.i32_val (static_cast <int >(op.getCacheControl ())),
331268 storeVal};
332269
333- LLVM::CallOp call =
334- createDeviceFunctionCall ( rewriter, funcName, void_ty (ctx), argTypes, args,
335- {}, noUnwindWillReturnAttrs);
270+ LLVM::CallOp call = intel::createDeviceFunctionCall (
271+ rewriter, funcName, void_ty (ctx), argTypes, args, {} ,
272+ intel:: noUnwindWillReturnAttrs);
336273 return call;
337274}
338275
@@ -374,8 +311,9 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
374311 b.i32_val (static_cast <int >(op.getCacheControl ()))};
375312
376313 const StringLiteral funcName = " llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid" ;
377- return createDeviceFunctionCall (rewriter, funcName, void_ty (ctx), {argTypes},
378- {args}, {}, noUnwindWillReturnAttrs);
314+ return intel::createDeviceFunctionCall (rewriter, funcName, void_ty (ctx),
315+ {argTypes}, {args}, {},
316+ intel::noUnwindWillReturnAttrs);
379317}
380318
381319namespace {
@@ -448,11 +386,11 @@ struct TritonMatrixDPASLowering
448386 /* other=*/ LLVM::ModRefInfo::NoModRef,
449387 /* argMem=*/ LLVM::ModRefInfo::NoModRef,
450388 /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
451- auto funcAttrs = convergentNoUnwindWillReturnAttrs;
389+ auto funcAttrs = intel:: convergentNoUnwindWillReturnAttrs;
452390 funcAttrs.memEffectsAttr = memAttr;
453391
454- Value result = createDeviceFunctionCall (rewriter, fnName, cTy, argTypes,
455- args, {}, funcAttrs)
392+ Value result = intel:: createDeviceFunctionCall (
393+ rewriter, fnName, cTy, argTypes, args, {}, funcAttrs)
456394 ->getResult (0 );
457395 if (cOrigTy != cTy)
458396 result = rewriter.create <LLVM::BitcastOp>(loc, cOrigTy, result);
@@ -524,9 +462,9 @@ struct TritonMatrix2DBlockLoadLowering
524462 std::make_pair (5 , LLVM::LLVMDialect::getWriteOnlyAttrName ()),
525463 };
526464
527- LLVM::CallOp call =
528- createDeviceFunctionCall ( rewriter, fnName, void_ty (ctx), argTypes, args,
529- paramAttrs, noUnwindWillReturnAttrs);
465+ LLVM::CallOp call = intel::createDeviceFunctionCall (
466+ rewriter, fnName, void_ty (ctx), argTypes, args, paramAttrs ,
467+ intel:: noUnwindWillReturnAttrs);
530468 constexpr uint32_t ptrOperandIndex = 0 ;
531469 if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
532470 loadCacheControlToCacheControls (rewriter, op.getCacheControl (),
@@ -588,9 +526,9 @@ struct TritonMatrix2DBlockStoreLowering
588526 std::make_pair (5 , LLVM::LLVMDialect::getReadonlyAttrName ()),
589527 };
590528
591- LLVM::CallOp call =
592- createDeviceFunctionCall ( rewriter, fnName, void_ty (ctx), argTypes, args,
593- paramAttrs, noUnwindWillReturnAttrs);
529+ LLVM::CallOp call = intel::createDeviceFunctionCall (
530+ rewriter, fnName, void_ty (ctx), argTypes, args, paramAttrs ,
531+ intel:: noUnwindWillReturnAttrs);
594532 constexpr uint32_t ptrOperandIndex = 0 ;
595533 if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
596534 storeCacheControlToCacheControls (rewriter, op.getCacheControl (),
@@ -638,10 +576,10 @@ struct TritonMatrix2DBlockPrefetchLowering
638576 /* other=*/ LLVM::ModRefInfo::NoModRef,
639577 /* argMem=*/ LLVM::ModRefInfo::Ref,
640578 /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
641- auto funcAttrs = noUnwindAttrs;
579+ auto funcAttrs = intel:: noUnwindAttrs;
642580 funcAttrs.memEffectsAttr = memAttr;
643581
644- LLVM::CallOp call = createDeviceFunctionCall (
582+ LLVM::CallOp call = intel:: createDeviceFunctionCall (
645583 rewriter, fnName, void_ty (ctx), argTypes, args, paramAttrs, funcAttrs);
646584 constexpr uint32_t ptrOperandIndex = 0 ;
647585 if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
@@ -705,9 +643,9 @@ struct TritonSubGroupBlockReadLowering
705643 /* other=*/ LLVM::ModRefInfo::NoModRef,
706644 /* argMem=*/ LLVM::ModRefInfo::Ref,
707645 /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
708- auto funcAttrs = noUnwindWillReturnAttrs;
646+ auto funcAttrs = intel:: noUnwindWillReturnAttrs;
709647 funcAttrs.memEffectsAttr = memAttr;
710- LLVM::CallOp call = createDeviceFunctionCall (
648+ LLVM::CallOp call = intel:: createDeviceFunctionCall (
711649 rewriter, funcName, type, {ptrTy}, {op.getPtr ()}, {}, funcAttrs, {});
712650
713651 rewriter.replaceOp (op, call.getResult ());
@@ -733,9 +671,9 @@ struct TritonSubGroupBlockWriteLowering
733671 /* other=*/ LLVM::ModRefInfo::NoModRef,
734672 /* argMem=*/ LLVM::ModRefInfo::ModRef,
735673 /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
736- auto funcAttrs = noUnwindWillReturnAttrs;
674+ auto funcAttrs = intel:: noUnwindWillReturnAttrs;
737675 funcAttrs.memEffectsAttr = memAttr;
738- LLVM::CallOp call = createDeviceFunctionCall (
676+ LLVM::CallOp call = intel:: createDeviceFunctionCall (
739677 rewriter, funcName, void_ty (ctx), {ptrTy, type},
740678 {op.getPtr (), op.getVal ()}, {}, funcAttrs);
741679
0 commit comments