1616
1717namespace mlir {
1818
19- // / Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20- // / `f32ApproxFunc` or `f16Func` depending on the element type and the
21- // / fastMathFlag of that Op. The function declaration is added in case it was
22- // / not added before.
23- // /
24- // / If the input values are of bf16 type (or f16 type if f16Func is empty), the
25- // / value is first casted to f32, the function called and then the result casted
26- // / back.
27- // /
28- // / Example with NVVM:
29- // / %exp_f32 = math.exp %arg_f32 : f32
30- // /
31- // / will be transformed into
32- // / llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
33- // /
34- // / If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
35- // / to the approximate calculation function.
36- // /
37- // / Also example with NVVM:
38- // / %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
39- // /
40- // / will be transformed into
41- // / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
42- template <typename SourceOp>
43- struct OpToFuncCallLowering : public ConvertOpToLLVMPattern <SourceOp> {
19+ template <typename SourceOp, typename DerivedTy>
20+ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern <SourceOp> {
4421public:
45- explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
46- StringRef f32Func, StringRef f64Func,
47- StringRef f32ApproxFunc, StringRef f16Func)
48- : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
22+ explicit OpToFuncCallLoweringBase (const LLVMTypeConverter &lowering)
23+ : ConvertOpToLLVMPattern<SourceOp>(lowering) {}
5024
5125 LogicalResult
5226 matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -72,13 +46,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
7246
7347 SmallVector<Value, 1 > castedOperands;
7448 for (Value operand : adaptor.getOperands ())
75- castedOperands.push_back (maybeCast (operand, rewriter));
49+ castedOperands.push_back (
50+ ((const DerivedTy *)this )->maybeCast (operand, rewriter));
7651
7752 Type resultType = castedOperands.front ().getType ();
7853 Type funcType = getFunctionType (resultType, castedOperands);
7954 StringRef funcName =
80- getFunctionName (cast<LLVM::LLVMFunctionType>(funcType).getReturnType (),
81- op.getFastmath ());
55+ ((const DerivedTy *)this )
56+ ->getFunctionName (
57+ cast<LLVM::LLVMFunctionType>(funcType).getReturnType (), op);
8258 if (funcName.empty ())
8359 return failure ();
8460
@@ -99,6 +75,61 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
9975 }
10076
10177private:
78+ Type getFunctionType (Type resultType, ValueRange operands) const {
79+ SmallVector<Type> operandTypes (operands.getTypes ());
80+ return LLVM::LLVMFunctionType::get (resultType, operandTypes);
81+ }
82+
83+ LLVM::LLVMFuncOp appendOrGetFuncOp (StringRef funcName, Type funcType,
84+ Operation *op) const {
85+ using LLVM::LLVMFuncOp;
86+
87+ auto funcAttr = StringAttr::get (op->getContext (), funcName);
88+ Operation *funcOp = SymbolTable::lookupNearestSymbolFrom (op, funcAttr);
89+ if (funcOp)
90+ return cast<LLVMFuncOp>(*funcOp);
91+
92+ mlir::OpBuilder b (op->getParentOfType <FunctionOpInterface>());
93+ return b.create <LLVMFuncOp>(op->getLoc (), funcName, funcType);
94+ }
95+ };
96+
97+ // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
98+ // / `f32ApproxFunc` or `f16Func` depending on the element type and the
99+ // / fastMathFlag of that Op. The function declaration is added in case it was
100+ // / not added before.
101+ // /
102+ // / If the input values are of bf16 type (or f16 type if f16Func is empty), the
103+ // / value is first casted to f32, the function called and then the result casted
104+ // / back.
105+ // /
106+ // / Example with NVVM:
107+ // / %exp_f32 = math.exp %arg_f32 : f32
108+ // /
109+ // / will be transformed into
110+ // / llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
111+ // /
112+ // / If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
113+ // / to the approximate calculation function.
114+ // /
115+ // / Also example with NVVM:
116+ // / %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
117+ // /
118+ // / will be transformed into
119+ // / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
120+ template <typename SourceOp>
121+ struct OpToFuncCallLowering
122+ : public OpToFuncCallLoweringBase<SourceOp,
123+ OpToFuncCallLowering<SourceOp>> {
124+ public:
125+ explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
126+ StringRef f32Func, StringRef f64Func,
127+ StringRef f32ApproxFunc, StringRef f16Func)
128+ : OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
129+ lowering),
130+ f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
131+ f16Func(f16Func) {}
132+
102133 Value maybeCast (Value operand, PatternRewriter &rewriter) const {
103134 Type type = operand.getType ();
104135 if (!isa<Float16Type, BFloat16Type>(type))
@@ -112,12 +143,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
112143 operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
113144 }
114145
115- Type getFunctionType (Type resultType, ValueRange operands) const {
116- SmallVector<Type> operandTypes (operands.getTypes ());
117- return LLVM::LLVMFunctionType::get (resultType, operandTypes);
118- }
119-
120- StringRef getFunctionName (Type type, arith::FastMathFlags flag) const {
146+ StringRef getFunctionName (Type type, SourceOp op) const {
147+ arith::FastMathFlags flag = op.getFastmath ();
121148 if (isa<Float16Type>(type))
122149 return f16Func;
123150 if (isa<Float32Type>(type)) {
@@ -132,23 +159,84 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
132159 return " " ;
133160 }
134161
135- LLVM::LLVMFuncOp appendOrGetFuncOp (StringRef funcName, Type funcType,
136- Operation *op) const {
137- using LLVM::LLVMFuncOp;
162+ const std::string f32Func;
163+ const std::string f64Func;
164+ const std::string f32ApproxFunc;
165+ const std::string f16Func;
166+ };
138167
139- auto funcAttr = StringAttr::get (op->getContext (), funcName);
140- Operation *funcOp = SymbolTable::lookupNearestSymbolFrom (op, funcAttr);
141- if (funcOp)
142- return cast<LLVMFuncOp>(*funcOp);
168+ // / Rewriting that replace SourceOp with a CallOp to `i32Func`
169+ // / The function declaration is added in case it was not added before.
170+ // / This assumes that all types integral.
171+ // /
172+ // / Example with NVVM:
173+ // / %abs_i32 = math.iabs %arg_i32 : i32
174+ // /
175+ // / will be transformed into
176+ // / llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
177+ // /
178+ template <typename SourceOp>
179+ struct IntOpToFuncCallLowering
180+ : public OpToFuncCallLoweringBase<SourceOp,
181+ IntOpToFuncCallLowering<SourceOp>> {
182+ public:
183+ explicit IntOpToFuncCallLowering (const LLVMTypeConverter &lowering,
184+ StringRef i32Func)
185+ : OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
186+ lowering),
187+ i32Func(i32Func) {}
143188
144- mlir::OpBuilder b (op->getParentOfType <FunctionOpInterface>());
145- return b.create <LLVMFuncOp>(op->getLoc (), funcName, funcType);
189+ Value maybeCast (Value operand, PatternRewriter &rewriter) const {
190+ return operand;
191+ }
192+
193+ StringRef getFunctionName (Type type, SourceOp op) const {
194+ IntegerType itype = dyn_cast<IntegerType>(type);
195+ if (!itype || itype.getWidth () != 32 )
196+ return " " ;
197+ return i32Func;
198+ }
199+
200+ const std::string i32Func;
201+ };
202+
203+ // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
204+ // / depending on the type of the result. This assumes that the first argument is
205+ // / a floating type and the second argument is an integer type.
206+ // /
207+ // / Example with NVVM:
208+ // / %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
209+ // /
210+ // / will be transformed into
211+ // / llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32
212+ // /
213+ template <typename SourceOp>
214+ struct FloatIntOpToFuncCallLowering
215+ : public OpToFuncCallLoweringBase<SourceOp,
216+ FloatIntOpToFuncCallLowering<SourceOp>> {
217+ public:
218+ explicit FloatIntOpToFuncCallLowering (const LLVMTypeConverter &lowering,
219+ StringRef f32Func, StringRef f64Func)
220+ : OpToFuncCallLoweringBase<SourceOp,
221+ FloatIntOpToFuncCallLowering<SourceOp>>(
222+ lowering),
223+ f32Func(f32Func), f64Func(f64Func) {}
224+
225+ Value maybeCast (Value operand, PatternRewriter &rewriter) const {
226+ return operand;
227+ }
228+
229+ StringRef getFunctionName (Type type, SourceOp op) const {
230+ if (isa<Float32Type>(type)) {
231+ return f32Func;
232+ }
233+ if (isa<Float64Type>(type))
234+ return f64Func;
235+ return " " ;
146236 }
147237
148238 const std::string f32Func;
149239 const std::string f64Func;
150- const std::string f32ApproxFunc;
151- const std::string f16Func;
152240};
153241
154242} // namespace mlir
0 commit comments