1616
1717namespace mlir {
1818
19- template <typename SourceOp, typename DerivedTy>
20- struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern <SourceOp> {
19+ namespace {
20+ // / Detection trait tor the `getFastmath` instance method.
21+ template <typename T>
22+ using has_get_fastmath_t = decltype (std::declval<T>().getFastmath());
23+ } // namespace
24+
25+ // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26+ // / `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and
27+ // / the fastMathFlag of that Op, if present. The function declaration is added
28+ // / in case it was not added before.
29+ // /
30+ // / If the input values are of bf16 type (or f16 type if f16Func is empty), the
31+ // / value is first casted to f32, the function called and then the result casted
32+ // / back.
33+ // /
34+ // / Example with NVVM:
35+ // / %exp_f32 = math.exp %arg_f32 : f32
36+ // /
37+ // / will be transformed into
38+ // / llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
39+ // /
40+ // / If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
41+ // / to the approximate calculation function.
42+ // /
43+ // / Also example with NVVM:
44+ // / %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
45+ // /
46+ // / will be transformed into
47+ // / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48+ // /
49+ // / Final example with NVVM:
50+ // / %pow_f32 = math.fpowi %arg_f32, %arg_i32
51+ // /
52+ // / will be transformed into
53+ // / llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
54+ template <typename SourceOp>
55+ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern <SourceOp> {
2156public:
22- explicit OpToFuncCallLoweringBase (const LLVMTypeConverter &lowering)
23- : ConvertOpToLLVMPattern<SourceOp>(lowering) {}
57+ explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
58+ StringRef f32Func, StringRef f64Func,
59+ StringRef f32ApproxFunc, StringRef f16Func,
60+ StringRef i32Func = " " )
61+ : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
62+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
63+ i32Func(i32Func) {}
2464
2565 LogicalResult
2666 matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -46,15 +86,12 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
4686
4787 SmallVector<Value, 1 > castedOperands;
4888 for (Value operand : adaptor.getOperands ())
49- castedOperands.push_back (
50- static_cast <const DerivedTy *>(this )->maybeCast (operand, rewriter));
89+ castedOperands.push_back (maybeCast (operand, rewriter));
5190
5291 Type resultType = castedOperands.front ().getType ();
5392 Type funcType = getFunctionType (resultType, castedOperands);
54- StringRef funcName =
55- static_cast <const DerivedTy *>(this )
56- ->getFunctionName (
57- cast<LLVM::LLVMFunctionType>(funcType).getReturnType (), op);
93+ StringRef funcName = getFunctionName (
94+ cast<LLVM::LLVMFunctionType>(funcType).getReturnType (), op);
5895 if (funcName.empty ())
5996 return failure ();
6097
@@ -67,14 +104,28 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
67104 return success ();
68105 }
69106
107+ assert (callOp.getResult ().getType ().isF32 () &&
108+ " only f32 types are supposed to be truncated back" );
70109 Value truncated = rewriter.create <LLVM::FPTruncOp>(
71110 op->getLoc (), adaptor.getOperands ().front ().getType (),
72111 callOp.getResult ());
73112 rewriter.replaceOp (op, {truncated});
74113 return success ();
75114 }
76115
77- private:
116+ Value maybeCast (Value operand, PatternRewriter &rewriter) const {
117+ Type type = operand.getType ();
118+ if (!isa<Float16Type, BFloat16Type>(type))
119+ return operand;
120+
121+ // if there's a f16 function, no need to cast f16 values
122+ if (!f16Func.empty () && isa<Float16Type>(type))
123+ return operand;
124+
125+ return rewriter.create <LLVM::FPExtOp>(
126+ operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
127+ }
128+
78129 Type getFunctionType (Type resultType, ValueRange operands) const {
79130 SmallVector<Type> operandTypes (operands.getTypes ());
80131 return LLVM::LLVMFunctionType::get (resultType, operandTypes);
@@ -85,7 +136,8 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
85136 using LLVM::LLVMFuncOp;
86137
87138 auto funcAttr = StringAttr::get (op->getContext (), funcName);
88- auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
139+ auto funcOp =
140+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
89141 if (funcOp)
90142 return funcOp;
91143
@@ -94,153 +146,37 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
94146 OpBuilder b (parentFunc);
95147 return b.create <LLVMFuncOp>(op->getLoc (), funcName, funcType);
96148 }
97- };
98-
99- // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
100- // / `f32ApproxFunc` or `f16Func` depending on the element type and the
101- // / fastMathFlag of that Op. The function declaration is added in case it was
102- // / not added before.
103- // /
104- // / If the input values are of bf16 type (or f16 type if f16Func is empty), the
105- // / value is first casted to f32, the function called and then the result casted
106- // / back.
107- // /
108- // / Example with NVVM:
109- // / %exp_f32 = math.exp %arg_f32 : f32
110- // /
111- // / will be transformed into
112- // / llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
113- // /
114- // / If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
115- // / to the approximate calculation function.
116- // /
117- // / Also example with NVVM:
118- // / %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
119- // /
120- // / will be transformed into
121- // / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
122- template <typename SourceOp>
123- struct OpToFuncCallLowering
124- : public OpToFuncCallLoweringBase<SourceOp,
125- OpToFuncCallLowering<SourceOp>> {
126- public:
127- explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
128- StringRef f32Func, StringRef f64Func,
129- StringRef f32ApproxFunc, StringRef f16Func)
130- : OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
131- lowering),
132- f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
133- f16Func(f16Func) {}
134-
135- Value maybeCast (Value operand, PatternRewriter &rewriter) const {
136- Type type = operand.getType ();
137- if (!isa<Float16Type, BFloat16Type>(type))
138- return operand;
139-
140- // if there's a f16 function, no need to cast f16 values
141- if (!f16Func.empty () && isa<Float16Type>(type))
142- return operand;
143-
144- return rewriter.create <LLVM::FPExtOp>(
145- operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
146- }
147149
148150 StringRef getFunctionName (Type type, SourceOp op) const {
149- arith::FastMathFlags flag = op.getFastmath ();
151+ bool useApprox = false ;
152+ if constexpr (llvm::is_detected<has_get_fastmath_t , SourceOp>::value) {
153+ arith::FastMathFlags flag = op.getFastmath ();
154+ useApprox = ((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
155+ !f32ApproxFunc.empty ();
156+ }
157+
150158 if (isa<Float16Type>(type))
151159 return f16Func;
152160 if (isa<Float32Type>(type)) {
153- if (((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
154- !f32ApproxFunc.empty ())
161+ if (useApprox)
155162 return f32ApproxFunc;
156- else
157- return f32Func;
163+ return f32Func;
158164 }
159165 if (isa<Float64Type>(type))
160166 return f64Func;
167+
168+ if (type.isInteger (32 ))
169+ return i32Func;
161170 return " " ;
162171 }
163172
164173 const std::string f32Func;
165174 const std::string f64Func;
166175 const std::string f32ApproxFunc;
167176 const std::string f16Func;
168- };
169-
170- // / Rewriting that replace SourceOp with a CallOp to `i32Func`
171- // / The function declaration is added in case it was not added before.
172- // / This assumes that all types integral.
173- // /
174- // / Example with NVVM:
175- // / %abs_i32 = math.iabs %arg_i32 : i32
176- // /
177- // / will be transformed into
178- // / llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
179- // /
180- template <typename SourceOp>
181- struct IntOpToFuncCallLowering
182- : public OpToFuncCallLoweringBase<SourceOp,
183- IntOpToFuncCallLowering<SourceOp>> {
184- public:
185- explicit IntOpToFuncCallLowering (const LLVMTypeConverter &lowering,
186- StringRef i32Func)
187- : OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
188- lowering),
189- i32Func(i32Func) {}
190-
191- Value maybeCast (Value operand, PatternRewriter &rewriter) const {
192- return operand;
193- }
194-
195- StringRef getFunctionName (Type type, SourceOp op) const {
196- IntegerType itype = dyn_cast<IntegerType>(type);
197- if (!itype || itype.getWidth () != 32 )
198- return " " ;
199- return i32Func;
200- }
201-
202177 const std::string i32Func;
203178};
204179
205- // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
206- // / depending on the type of the result. This assumes that the first argument is
207- // / a floating type and the second argument is an integer type.
208- // /
209- // / Example with NVVM:
210- // / %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
211- // /
212- // / will be transformed into
213- // / llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32
214- // /
215- template <typename SourceOp>
216- struct FloatIntOpToFuncCallLowering
217- : public OpToFuncCallLoweringBase<SourceOp,
218- FloatIntOpToFuncCallLowering<SourceOp>> {
219- public:
220- explicit FloatIntOpToFuncCallLowering (const LLVMTypeConverter &lowering,
221- StringRef f32Func, StringRef f64Func)
222- : OpToFuncCallLoweringBase<SourceOp,
223- FloatIntOpToFuncCallLowering<SourceOp>>(
224- lowering),
225- f32Func(f32Func), f64Func(f64Func) {}
226-
227- Value maybeCast (Value operand, PatternRewriter &rewriter) const {
228- return operand;
229- }
230-
231- StringRef getFunctionName (Type type, SourceOp op) const {
232- if (isa<Float32Type>(type)) {
233- return f32Func;
234- }
235- if (isa<Float64Type>(type))
236- return f64Func;
237- return " " ;
238- }
239-
240- const std::string f32Func;
241- const std::string f64Func;
242- };
243-
244180} // namespace mlir
245181
246182#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
0 commit comments