Skip to content

Commit 79015d5

Browse files
authored
Move utility to create function call into common utility file (#3338)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 6042515 commit 79015d5

File tree

4 files changed, with 109 additions and 87 deletions.

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 25 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "Attributes.h"
10+
#include "Utils/LLVMIntr.h"
1011
#include "Utils/Mangling.h"
1112
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1213
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -52,71 +53,6 @@ using namespace mlir::triton::gpu;
5253
// Helper Functions
5354
//===----------------------------------------------------------------------===//
5455

55-
static intel::AttributeList createFunctionAttributes(
56-
ArrayRef<std::pair<llvm::Attribute::AttrKind, std::optional<uint64_t>>>
57-
attributes,
58-
MLIRContext *ctx) {
59-
intel::AttrBuilder funcAttrBuilder(*ctx);
60-
for (auto [kind, optValue] : attributes) {
61-
if (optValue)
62-
funcAttrBuilder.addPassthroughAttribute(kind, *optValue);
63-
else
64-
funcAttrBuilder.addPassthroughAttribute(kind);
65-
}
66-
67-
intel::AttributeList attrs;
68-
attrs.addFnAttributes(funcAttrBuilder);
69-
return attrs;
70-
}
71-
72-
struct LLVMFuncAttributeOptions {
73-
bool isConvergent = false;
74-
bool isNoUnwind = false;
75-
bool isWillReturn = false;
76-
LLVM::MemoryEffectsAttr memEffectsAttr{};
77-
};
78-
79-
static constexpr LLVMFuncAttributeOptions convergentAttrs = {
80-
true, false, false, {}};
81-
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
82-
false, true, false, {}};
83-
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
84-
false, true, true, {}};
85-
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
86-
true, true, true, {}};
87-
88-
static LLVM::CallOp createDeviceFunctionCall(
89-
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
90-
ArrayRef<Type> argTypes, ArrayRef<Value> args,
91-
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
92-
const LLVMFuncAttributeOptions &funcAttributeOptions,
93-
const intel::AttributeList &passthroughAttrs = {}) {
94-
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
95-
MLIRContext *ctx = rewriter.getContext();
96-
Location loc = UnknownLoc::get(ctx);
97-
98-
LLVM::LLVMFuncOp funcOp =
99-
LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType);
100-
funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
101-
funcOp.setConvergent(funcAttributeOptions.isConvergent);
102-
funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
103-
funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
104-
105-
if (funcAttributeOptions.memEffectsAttr)
106-
funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
107-
108-
for (auto [idx, attrName] : paramAttrs)
109-
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
110-
111-
if (!passthroughAttrs.getFnAttributes().empty())
112-
funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));
113-
114-
auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
115-
callOp->setAttrs(funcOp->getAttrs());
116-
117-
return callOp;
118-
}
119-
12056
[[maybe_unused]] static std::string getGenISATypeMangling(Type ty) {
12157
if (auto vecTy = dyn_cast<VectorType>(ty))
12258
return "v" + std::to_string(vecTy.getNumElements()) +
@@ -230,8 +166,9 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
230166
b.i1_val(op.getVnniTransform()),
231167
b.i32_val(static_cast<int>(op.getCacheControl()))};
232168

233-
LLVM::CallOp call = createDeviceFunctionCall(
234-
rewriter, funcName, resType, argTypes, args, {}, noUnwindWillReturnAttrs);
169+
LLVM::CallOp call =
170+
intel::createDeviceFunctionCall(rewriter, funcName, resType, argTypes,
171+
args, {}, intel::noUnwindWillReturnAttrs);
235172
return call.getResult();
236173
}
237174

@@ -330,9 +267,9 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
330267
b.i32_val(static_cast<int>(op.getCacheControl())),
331268
storeVal};
332269

333-
LLVM::CallOp call =
334-
createDeviceFunctionCall(rewriter, funcName, void_ty(ctx), argTypes, args,
335-
{}, noUnwindWillReturnAttrs);
270+
LLVM::CallOp call = intel::createDeviceFunctionCall(
271+
rewriter, funcName, void_ty(ctx), argTypes, args, {},
272+
intel::noUnwindWillReturnAttrs);
336273
return call;
337274
}
338275

@@ -374,8 +311,9 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
374311
b.i32_val(static_cast<int>(op.getCacheControl()))};
375312

376313
const StringLiteral funcName = "llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid";
377-
return createDeviceFunctionCall(rewriter, funcName, void_ty(ctx), {argTypes},
378-
{args}, {}, noUnwindWillReturnAttrs);
314+
return intel::createDeviceFunctionCall(rewriter, funcName, void_ty(ctx),
315+
{argTypes}, {args}, {},
316+
intel::noUnwindWillReturnAttrs);
379317
}
380318

381319
namespace {
@@ -448,11 +386,11 @@ struct TritonMatrixDPASLowering
448386
/*other=*/LLVM::ModRefInfo::NoModRef,
449387
/*argMem=*/LLVM::ModRefInfo::NoModRef,
450388
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
451-
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
389+
auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs;
452390
funcAttrs.memEffectsAttr = memAttr;
453391

454-
Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
455-
args, {}, funcAttrs)
392+
Value result = intel::createDeviceFunctionCall(
393+
rewriter, fnName, cTy, argTypes, args, {}, funcAttrs)
456394
->getResult(0);
457395
if (cOrigTy != cTy)
458396
result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
@@ -524,9 +462,9 @@ struct TritonMatrix2DBlockLoadLowering
524462
std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
525463
};
526464

527-
LLVM::CallOp call =
528-
createDeviceFunctionCall(rewriter, fnName, void_ty(ctx), argTypes, args,
529-
paramAttrs, noUnwindWillReturnAttrs);
465+
LLVM::CallOp call = intel::createDeviceFunctionCall(
466+
rewriter, fnName, void_ty(ctx), argTypes, args, paramAttrs,
467+
intel::noUnwindWillReturnAttrs);
530468
constexpr uint32_t ptrOperandIndex = 0;
531469
if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
532470
loadCacheControlToCacheControls(rewriter, op.getCacheControl(),
@@ -588,9 +526,9 @@ struct TritonMatrix2DBlockStoreLowering
588526
std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
589527
};
590528

591-
LLVM::CallOp call =
592-
createDeviceFunctionCall(rewriter, fnName, void_ty(ctx), argTypes, args,
593-
paramAttrs, noUnwindWillReturnAttrs);
529+
LLVM::CallOp call = intel::createDeviceFunctionCall(
530+
rewriter, fnName, void_ty(ctx), argTypes, args, paramAttrs,
531+
intel::noUnwindWillReturnAttrs);
594532
constexpr uint32_t ptrOperandIndex = 0;
595533
if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
596534
storeCacheControlToCacheControls(rewriter, op.getCacheControl(),
@@ -638,10 +576,10 @@ struct TritonMatrix2DBlockPrefetchLowering
638576
/*other=*/LLVM::ModRefInfo::NoModRef,
639577
/*argMem=*/LLVM::ModRefInfo::Ref,
640578
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
641-
auto funcAttrs = noUnwindAttrs;
579+
auto funcAttrs = intel::noUnwindAttrs;
642580
funcAttrs.memEffectsAttr = memAttr;
643581

644-
LLVM::CallOp call = createDeviceFunctionCall(
582+
LLVM::CallOp call = intel::createDeviceFunctionCall(
645583
rewriter, fnName, void_ty(ctx), argTypes, args, paramAttrs, funcAttrs);
646584
constexpr uint32_t ptrOperandIndex = 0;
647585
if (std::optional<TritonGEN::DecorationCacheControlAttr> optCacheControls =
@@ -705,9 +643,9 @@ struct TritonSubGroupBlockReadLowering
705643
/*other=*/LLVM::ModRefInfo::NoModRef,
706644
/*argMem=*/LLVM::ModRefInfo::Ref,
707645
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
708-
auto funcAttrs = noUnwindWillReturnAttrs;
646+
auto funcAttrs = intel::noUnwindWillReturnAttrs;
709647
funcAttrs.memEffectsAttr = memAttr;
710-
LLVM::CallOp call = createDeviceFunctionCall(
648+
LLVM::CallOp call = intel::createDeviceFunctionCall(
711649
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
712650

713651
rewriter.replaceOp(op, call.getResult());
@@ -733,9 +671,9 @@ struct TritonSubGroupBlockWriteLowering
733671
/*other=*/LLVM::ModRefInfo::NoModRef,
734672
/*argMem=*/LLVM::ModRefInfo::ModRef,
735673
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
736-
auto funcAttrs = noUnwindWillReturnAttrs;
674+
auto funcAttrs = intel::noUnwindWillReturnAttrs;
737675
funcAttrs.memEffectsAttr = memAttr;
738-
LLVM::CallOp call = createDeviceFunctionCall(
676+
LLVM::CallOp call = intel::createDeviceFunctionCall(
739677
rewriter, funcName, void_ty(ctx), {ptrTy, type},
740678
{op.getPtr(), op.getVal()}, {}, funcAttrs);
741679

third_party/intel/lib/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_triton_library(TritonIntelUtils
2+
LLVMIntr.cpp
23
Mangling.cpp
34

45
LINK_LIBS PUBLIC
third_party/intel/lib/Utils/LLVMIntr.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "LLVMIntr.h"
2+
3+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
4+
5+
namespace mlir::triton::gpu::intel {
6+
7+
/// Emits a call to the device function `funcName`.
///
/// The callee is obtained with LLVM::lookupOrCreateFn on the enclosing module
/// using `argTypes` and `retType`. It is then configured with the calling
/// convention `cc`, the convergent / nounwind / willreturn flags and the
/// optional memory effects from `funcAttributeOptions`, and a unit attribute
/// for every (argument index, attribute name) pair in `paramAttrs`.
///
/// Returns the LLVM::CallOp created through `rewriter` with `args` as operands.
LLVM::CallOp createDeviceFunctionCall(
8+
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
9+
ArrayRef<Type> argTypes, ArrayRef<Value> args,
10+
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
11+
const LLVMFuncAttributeOptions &funcAttributeOptions,
12+
const intel::AttributeList &passthroughAttrs, LLVM::cconv::CConv cc) {
13+
// Module that encloses the block the rewriter is currently positioned in.
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
14+
MLIRContext *ctx = rewriter.getContext();
15+
// The call is created with an unknown source location.
Location loc = UnknownLoc::get(ctx);
16+
17+
LLVM::LLVMFuncOp funcOp =
18+
LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType);
19+
funcOp.setCConv(cc);
20+
funcOp.setConvergent(funcAttributeOptions.isConvergent);
21+
funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
22+
funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
23+
24+
// Memory effects are only set when the caller supplied a non-null attribute.
if (funcAttributeOptions.memEffectsAttr)
25+
funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
26+
27+
// Each entry marks argument `idx` of the callee with the unit attribute
// named `attrName`.
for (auto [idx, attrName] : paramAttrs)
28+
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
29+
30+
// NOTE(review): Operation::setAttrs appears to replace the op's whole
// attribute dictionary rather than merge into it -- confirm that the
// passthrough dictionary is meant to supersede the attributes set above.
if (!passthroughAttrs.getFnAttributes().empty())
31+
funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx));
32+
33+
auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
34+
// Set the call op's attributes to those of the callee function op.
// NOTE(review): this also looks like a wholesale replacement of the call's
// attribute dictionary -- verify that is intended.
callOp->setAttrs(funcOp->getAttrs());
35+
36+
return callOp;
37+
}
38+
39+
} // namespace mlir::triton::gpu::intel
LLVMIntr.h (new header; its directory is not shown in this capture)

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- LLVMIntr.h - Utilities to emit LLVM intrinsic calls ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef TRITON_INTEL_UTILS_LLVMINTR_H
10+
#define TRITON_INTEL_UTILS_LLVMINTR_H
11+
12+
#include "intel/lib/TritonGENToLLVM/Attributes.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Transforms/DialectConversion.h"
15+
16+
#include <string>
17+
18+
namespace mlir::triton::gpu::intel {
19+
20+
/// Function-level attribute options passed to createDeviceFunctionCall.
/// All flags default to false and the memory-effects attribute defaults to a
/// value-initialized (empty) attribute.
struct LLVMFuncAttributeOptions {
21+
// Whether the callee is marked convergent.
bool isConvergent = false;
22+
// Whether the callee is marked nounwind.
bool isNoUnwind = false;
23+
// Whether the callee is marked willreturn.
bool isWillReturn = false;
24+
// Optional memory effects for the callee; left empty when not specified.
LLVM::MemoryEffectsAttr memEffectsAttr{};
25+
};
26+
27+
// Preset LLVMFuncAttributeOptions values. The aggregate initializers follow
// the struct's field order: {isConvergent, isNoUnwind, isWillReturn,
// memEffectsAttr}. Every preset leaves memEffectsAttr empty; callers that
// need memory effects copy a preset and assign the field on the copy.

// convergent only.
constexpr LLVMFuncAttributeOptions convergentAttrs = {true, false, false, {}};
27+
// nounwind only.
constexpr LLVMFuncAttributeOptions noUnwindAttrs = {false, true, false, {}};
28+
// nounwind + willreturn.
constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
29+
false, true, true, {}};
30+
// convergent + nounwind + willreturn.
constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
31+
true, true, true, {}};
33+
34+
/// Creates a call to the device function named `funcName`.
///
/// \param rewriter             Rewriter used to create the call operation.
/// \param funcName             Name of the function to call.
/// \param retType              Return type of the function.
/// \param argTypes             Types of the function's parameters.
/// \param args                 Operands of the call.
/// \param paramAttrs           (argument index, attribute name) pairs.
/// \param funcAttributeOptions Function-level attribute options.
/// \param passthroughAttrs     Additional function attributes; defaults to an
///                             empty list.
/// \param cc                   Calling convention; defaults to SPIR_FUNC.
/// \returns the created LLVM::CallOp.
LLVM::CallOp createDeviceFunctionCall(
35+
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
36+
mlir::ArrayRef<mlir::Type> argTypes, ArrayRef<Value> args,
37+
ArrayRef<std::pair<unsigned, StringRef>> paramAttrs,
38+
const LLVMFuncAttributeOptions &funcAttributeOptions,
39+
const AttributeList &passthroughAttrs = {},
40+
LLVM::cconv::CConv cc = LLVM::cconv::CConv::SPIR_FUNC);
41+
42+
} // namespace mlir::triton::gpu::intel
43+
44+
#endif // TRITON_INTEL_UTILS_LLVMINTR_H

0 commit comments

Comments
 (0)