1616
1717namespace mlir {
1818
19- // / Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20- // / `f32ApproxFunc` or `f16Func` depending on the element type and the
21- // / fastMathFlag of that Op. The function declaration is added in case it was
22- // / not added before.
19+ namespace {
20+ // / Detection trait tor the `getFastmath` instance method.
21+ template <typename T>
22+ using has_get_fastmath_t = decltype (std::declval<T>().getFastmath());
23+ } // namespace
24+
25+ // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26+ // / `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and
27+ // / the fastMathFlag of that Op, if present. The function declaration is added
28+ // / in case it was not added before.
2329// /
2430// / If the input values are of bf16 type (or f16 type if f16Func is empty), the
2531// / value is first casted to f32, the function called and then the result casted
@@ -39,14 +45,22 @@ namespace mlir {
3945// /
4046// / will be transformed into
4147// / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48+ // /
49+ // / Final example with NVVM:
50+ // / %pow_f32 = math.fpowi %arg_f32, %arg_i32
51+ // /
52+ // / will be transformed into
53+ // / llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
4254template <typename SourceOp>
4355struct OpToFuncCallLowering : public ConvertOpToLLVMPattern <SourceOp> {
4456public:
4557 explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
4658 StringRef f32Func, StringRef f64Func,
47- StringRef f32ApproxFunc, StringRef f16Func)
59+ StringRef f32ApproxFunc, StringRef f16Func,
60+ StringRef i32Func = " " )
4861 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
62+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
63+ i32Func(i32Func) {}
5064
5165 LogicalResult
5266 matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -76,9 +90,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
7690
7791 Type resultType = castedOperands.front ().getType ();
7892 Type funcType = getFunctionType (resultType, castedOperands);
79- StringRef funcName =
80- getFunctionName (cast<LLVM::LLVMFunctionType>(funcType).getReturnType (),
81- op.getFastmath ());
93+ StringRef funcName = getFunctionName (
94+ cast<LLVM::LLVMFunctionType>(funcType).getReturnType (), op);
8295 if (funcName.empty ())
8396 return failure ();
8497
@@ -91,14 +104,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
91104 return success ();
92105 }
93106
107+ assert (callOp.getResult ().getType ().isF32 () &&
108+ " only f32 types are supposed to be truncated back" );
94109 Value truncated = rewriter.create <LLVM::FPTruncOp>(
95110 op->getLoc (), adaptor.getOperands ().front ().getType (),
96111 callOp.getResult ());
97112 rewriter.replaceOp (op, {truncated});
98113 return success ();
99114 }
100115
101- private:
102116 Value maybeCast (Value operand, PatternRewriter &rewriter) const {
103117 Type type = operand.getType ();
104118 if (!isa<Float16Type, BFloat16Type>(type))
@@ -117,38 +131,50 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
117131 return LLVM::LLVMFunctionType::get (resultType, operandTypes);
118132 }
119133
120- StringRef getFunctionName (Type type, arith::FastMathFlags flag) const {
121- if (isa<Float16Type>(type))
122- return f16Func;
123- if (isa<Float32Type>(type)) {
124- if (((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
125- !f32ApproxFunc.empty ())
126- return f32ApproxFunc;
127- else
128- return f32Func;
129- }
130- if (isa<Float64Type>(type))
131- return f64Func;
132- return " " ;
133- }
134-
135134 LLVM::LLVMFuncOp appendOrGetFuncOp (StringRef funcName, Type funcType,
136135 Operation *op) const {
137136 using LLVM::LLVMFuncOp;
138137
139138 auto funcAttr = StringAttr::get (op->getContext (), funcName);
140- Operation *funcOp = SymbolTable::lookupNearestSymbolFrom (op, funcAttr);
139+ auto funcOp =
140+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
141141 if (funcOp)
142- return cast<LLVMFuncOp>(* funcOp) ;
142+ return funcOp;
143143
144- mlir::OpBuilder b (op->getParentOfType <FunctionOpInterface>());
144+ auto parentFunc = op->getParentOfType <FunctionOpInterface>();
145+ assert (parentFunc && " expected there to be a parent function" );
146+ OpBuilder b (parentFunc);
145147 return b.create <LLVMFuncOp>(op->getLoc (), funcName, funcType);
146148 }
147149
150+ StringRef getFunctionName (Type type, SourceOp op) const {
151+ bool useApprox = false ;
152+ if constexpr (llvm::is_detected<has_get_fastmath_t , SourceOp>::value) {
153+ arith::FastMathFlags flag = op.getFastmath ();
154+ useApprox = ((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
155+ !f32ApproxFunc.empty ();
156+ }
157+
158+ if (isa<Float16Type>(type))
159+ return f16Func;
160+ if (isa<Float32Type>(type)) {
161+ if (useApprox)
162+ return f32ApproxFunc;
163+ return f32Func;
164+ }
165+ if (isa<Float64Type>(type))
166+ return f64Func;
167+
168+ if (type.isInteger (32 ))
169+ return i32Func;
170+ return " " ;
171+ }
172+
148173 const std::string f32Func;
149174 const std::string f64Func;
150175 const std::string f32ApproxFunc;
151176 const std::string f16Func;
177+ const std::string i32Func;
152178};
153179
154180} // namespace mlir
0 commit comments