Skip to content

Commit 7c849d7

Browse files
committed
[MLIR][Math][GPU] Add lowering of absi and fpowi to libdevice
1 parent 22d4ff1 commit 7c849d7

File tree

3 files changed

+181
-50
lines changed

3 files changed

+181
-50
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 138 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,37 +16,11 @@
1616

1717
namespace mlir {
1818

19-
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20-
/// `f32ApproxFunc` or `f16Func` depending on the element type and the
21-
/// fastMathFlag of that Op. The function declaration is added in case it was
22-
/// not added before.
23-
///
24-
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
25-
/// value is first casted to f32, the function called and then the result casted
26-
/// back.
27-
///
28-
/// Example with NVVM:
29-
/// %exp_f32 = math.exp %arg_f32 : f32
30-
///
31-
/// will be transformed into
32-
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
33-
///
34-
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
35-
/// to the approximate calculation function.
36-
///
37-
/// Also example with NVVM:
38-
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
39-
///
40-
/// will be transformed into
41-
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
42-
template <typename SourceOp>
43-
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
19+
template <typename SourceOp, typename DerivedTy>
20+
struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
4421
public:
45-
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
46-
StringRef f32Func, StringRef f64Func,
47-
StringRef f32ApproxFunc, StringRef f16Func)
48-
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49-
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
22+
explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering)
23+
: ConvertOpToLLVMPattern<SourceOp>(lowering) {}
5024

5125
LogicalResult
5226
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -72,13 +46,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
7246

7347
SmallVector<Value, 1> castedOperands;
7448
for (Value operand : adaptor.getOperands())
75-
castedOperands.push_back(maybeCast(operand, rewriter));
49+
castedOperands.push_back(
50+
((const DerivedTy *)this)->maybeCast(operand, rewriter));
7651

7752
Type resultType = castedOperands.front().getType();
7853
Type funcType = getFunctionType(resultType, castedOperands);
7954
StringRef funcName =
80-
getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
81-
op.getFastmath());
55+
((const DerivedTy *)this)
56+
->getFunctionName(
57+
cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
8258
if (funcName.empty())
8359
return failure();
8460

@@ -99,6 +75,61 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
9975
}
10076

10177
private:
78+
Type getFunctionType(Type resultType, ValueRange operands) const {
79+
SmallVector<Type> operandTypes(operands.getTypes());
80+
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
81+
}
82+
83+
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
84+
Operation *op) const {
85+
using LLVM::LLVMFuncOp;
86+
87+
auto funcAttr = StringAttr::get(op->getContext(), funcName);
88+
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
89+
if (funcOp)
90+
return cast<LLVMFuncOp>(*funcOp);
91+
92+
mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
93+
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
94+
}
95+
};
96+
97+
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
98+
/// `f32ApproxFunc` or `f16Func` depending on the element type and the
99+
/// fastMathFlag of that Op. The function declaration is added in case it was
100+
/// not added before.
101+
///
102+
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
103+
/// value is first cast to f32, the function called and then the result cast
104+
/// back.
105+
///
106+
/// Example with NVVM:
107+
/// %exp_f32 = math.exp %arg_f32 : f32
108+
///
109+
/// will be transformed into
110+
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
111+
///
112+
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
113+
/// to the approximate calculation function.
114+
///
115+
/// Another example with NVVM:
116+
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
117+
///
118+
/// will be transformed into
119+
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
120+
template <typename SourceOp>
121+
struct OpToFuncCallLowering
122+
: public OpToFuncCallLoweringBase<SourceOp,
123+
OpToFuncCallLowering<SourceOp>> {
124+
public:
125+
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
126+
StringRef f32Func, StringRef f64Func,
127+
StringRef f32ApproxFunc, StringRef f16Func)
128+
: OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
129+
lowering),
130+
f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
131+
f16Func(f16Func) {}
132+
102133
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
103134
Type type = operand.getType();
104135
if (!isa<Float16Type, BFloat16Type>(type))
@@ -112,12 +143,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
112143
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
113144
}
114145

115-
Type getFunctionType(Type resultType, ValueRange operands) const {
116-
SmallVector<Type> operandTypes(operands.getTypes());
117-
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
118-
}
119-
120-
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
146+
StringRef getFunctionName(Type type, SourceOp op) const {
147+
arith::FastMathFlags flag = op.getFastmath();
121148
if (isa<Float16Type>(type))
122149
return f16Func;
123150
if (isa<Float32Type>(type)) {
@@ -132,23 +159,84 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
132159
return "";
133160
}
134161

135-
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
136-
Operation *op) const {
137-
using LLVM::LLVMFuncOp;
162+
const std::string f32Func;
163+
const std::string f64Func;
164+
const std::string f32ApproxFunc;
165+
const std::string f16Func;
166+
};
138167

139-
auto funcAttr = StringAttr::get(op->getContext(), funcName);
140-
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
141-
if (funcOp)
142-
return cast<LLVMFuncOp>(*funcOp);
168+
/// Rewriting that replaces SourceOp with a CallOp to `i32Func`.
169+
/// The function declaration is added in case it was not added before.
170+
/// This assumes that all types are integral.
171+
///
172+
/// Example with NVVM:
173+
/// %abs_i32 = math.absi %arg_i32 : i32
174+
///
175+
/// will be transformed into
176+
/// llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
177+
///
178+
template <typename SourceOp>
179+
struct IntOpToFuncCallLowering
180+
: public OpToFuncCallLoweringBase<SourceOp,
181+
IntOpToFuncCallLowering<SourceOp>> {
182+
public:
183+
explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
184+
StringRef i32Func)
185+
: OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
186+
lowering),
187+
i32Func(i32Func) {}
143188

144-
mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
145-
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
189+
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
190+
return operand;
191+
}
192+
193+
StringRef getFunctionName(Type type, SourceOp op) const {
194+
IntegerType itype = dyn_cast<IntegerType>(type);
195+
if (!itype || itype.getWidth() != 32)
196+
return "";
197+
return i32Func;
198+
}
199+
200+
const std::string i32Func;
201+
};
202+
203+
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
204+
/// depending on the type of the result. This assumes that the first argument is
205+
/// a floating type and the second argument is an integer type.
206+
///
207+
/// Example with NVVM:
208+
/// %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
209+
///
210+
/// will be transformed into
211+
/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
212+
///
213+
template <typename SourceOp>
214+
struct FloatIntOpToFuncCallLowering
215+
: public OpToFuncCallLoweringBase<SourceOp,
216+
FloatIntOpToFuncCallLowering<SourceOp>> {
217+
public:
218+
explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
219+
StringRef f32Func, StringRef f64Func)
220+
: OpToFuncCallLoweringBase<SourceOp,
221+
FloatIntOpToFuncCallLowering<SourceOp>>(
222+
lowering),
223+
f32Func(f32Func), f64Func(f64Func) {}
224+
225+
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
226+
return operand;
227+
}
228+
229+
StringRef getFunctionName(Type type, SourceOp op) const {
230+
if (isa<Float32Type>(type)) {
231+
return f32Func;
232+
}
233+
if (isa<Float64Type>(type))
234+
return f64Func;
235+
return "";
146236
}
147237

148238
const std::string f32Func;
149239
const std::string f64Func;
150-
const std::string f32ApproxFunc;
151-
const std::string f16Func;
152240
};
153241

154242
} // namespace mlir

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
446446
f32ApproxFunc, f16Func);
447447
}
448448

449+
template <typename OpTy>
450+
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
451+
RewritePatternSet &patterns,
452+
StringRef i32Func) {
453+
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
454+
patterns.add<IntOpToFuncCallLowering<OpTy>>(converter, i32Func);
455+
}
456+
457+
template <typename OpTy>
458+
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
459+
RewritePatternSet &patterns,
460+
StringRef f32Func, StringRef f64Func) {
461+
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
462+
patterns.add<FloatIntOpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
463+
}
464+
449465
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
450466
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
451467
patterns.add<GPUSubgroupReduceOpLowering>(converter);
@@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
509525

510526
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
511527
"__nv_fmod");
528+
populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
512529
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
513530
"__nv_fabs");
514531
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
@@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
555572
"__nv_log2", "__nv_fast_log2f");
556573
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
557574
"__nv_fast_powf");
575+
populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
576+
"__nv_powi");
558577
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
559578
"__nv_round");
560579
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} {
10331033
transform.yield
10341034
}
10351035
}
1036+
1037+
1038+
gpu.module @test_module_52 {
1039+
// CHECK: llvm.func @__nv_abs(i32) -> i32
1040+
// CHECK-LABEL: func @gpu_abs
1041+
func.func @gpu_abs(%arg_i32 : i32) -> (i32) {
1042+
%result32 = math.absi %arg_i32 : i32
1043+
// CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32
1044+
func.return %result32 : i32
1045+
}
1046+
}
1047+
1048+
gpu.module @test_module_53 {
1049+
// CHECK: llvm.func @__nv_powif(f32, i32) -> f32
1050+
// CHECK: llvm.func @__nv_powi(f64, i32) -> f64
1051+
// CHECK-LABEL: func @gpu_powi
1052+
func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) {
1053+
%result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
1054+
// CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32
1055+
%result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32
1056+
// CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64
1057+
func.return %result32, %result64 : f32, f64
1058+
}
1059+
}

0 commit comments

Comments
 (0)