Skip to content

Commit cd31485

Browse files
committed
replace CRTP with a simple trait in a common flow
1 parent d6b89f0 commit cd31485

File tree

2 files changed

+78
-142
lines changed

2 files changed

+78
-142
lines changed

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 76 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,51 @@
1616

1717
namespace mlir {
1818

19-
template <typename SourceOp, typename DerivedTy>
20-
struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
19+
namespace {
20+
/// Detection trait for the `getFastmath` instance method.
21+
template <typename T>
22+
using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
23+
} // namespace
24+
25+
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26+
/// `f32ApproxFunc` or `f16Func` or `i32Func` depending on the element type and
27+
/// the fastMathFlag of that Op, if present. The function declaration is added
28+
/// in case it was not added before.
29+
///
30+
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
31+
/// value is first cast to f32, the function is called, and the result is cast
32+
/// back.
33+
///
34+
/// Example with NVVM:
35+
/// %exp_f32 = math.exp %arg_f32 : f32
36+
///
37+
/// will be transformed into
38+
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
39+
///
40+
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
41+
/// to the approximate calculation function.
42+
///
43+
/// Also example with NVVM:
44+
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
45+
///
46+
/// will be transformed into
47+
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48+
///
49+
/// Final example with NVVM:
50+
/// %pow_f32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
51+
///
52+
/// will be transformed into
53+
/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
54+
template <typename SourceOp>
55+
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
2156
public:
22-
explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering)
23-
: ConvertOpToLLVMPattern<SourceOp>(lowering) {}
57+
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
58+
StringRef f32Func, StringRef f64Func,
59+
StringRef f32ApproxFunc, StringRef f16Func,
60+
StringRef i32Func = "")
61+
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
62+
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
63+
i32Func(i32Func) {}
2464

2565
LogicalResult
2666
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -46,15 +86,12 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
4686

4787
SmallVector<Value, 1> castedOperands;
4888
for (Value operand : adaptor.getOperands())
49-
castedOperands.push_back(
50-
static_cast<const DerivedTy *>(this)->maybeCast(operand, rewriter));
89+
castedOperands.push_back(maybeCast(operand, rewriter));
5190

5291
Type resultType = castedOperands.front().getType();
5392
Type funcType = getFunctionType(resultType, castedOperands);
54-
StringRef funcName =
55-
static_cast<const DerivedTy *>(this)
56-
->getFunctionName(
57-
cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
93+
StringRef funcName = getFunctionName(
94+
cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
5895
if (funcName.empty())
5996
return failure();
6097

@@ -67,14 +104,28 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
67104
return success();
68105
}
69106

107+
assert(callOp.getResult().getType().isF32() &&
108+
"only f32 types are supposed to be truncated back");
70109
Value truncated = rewriter.create<LLVM::FPTruncOp>(
71110
op->getLoc(), adaptor.getOperands().front().getType(),
72111
callOp.getResult());
73112
rewriter.replaceOp(op, {truncated});
74113
return success();
75114
}
76115

77-
private:
116+
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
117+
Type type = operand.getType();
118+
if (!isa<Float16Type, BFloat16Type>(type))
119+
return operand;
120+
121+
// If there is an f16 function, f16 values do not need to be cast.
122+
if (!f16Func.empty() && isa<Float16Type>(type))
123+
return operand;
124+
125+
return rewriter.create<LLVM::FPExtOp>(
126+
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
127+
}
128+
78129
Type getFunctionType(Type resultType, ValueRange operands) const {
79130
SmallVector<Type> operandTypes(operands.getTypes());
80131
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
@@ -85,7 +136,8 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
85136
using LLVM::LLVMFuncOp;
86137

87138
auto funcAttr = StringAttr::get(op->getContext(), funcName);
88-
auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
139+
auto funcOp =
140+
SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
89141
if (funcOp)
90142
return funcOp;
91143

@@ -94,153 +146,37 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
94146
OpBuilder b(parentFunc);
95147
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
96148
}
97-
};
98-
99-
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
100-
/// `f32ApproxFunc` or `f16Func` depending on the element type and the
101-
/// fastMathFlag of that Op. The function declaration is added in case it was
102-
/// not added before.
103-
///
104-
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
105-
/// value is first casted to f32, the function called and then the result casted
106-
/// back.
107-
///
108-
/// Example with NVVM:
109-
/// %exp_f32 = math.exp %arg_f32 : f32
110-
///
111-
/// will be transformed into
112-
/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
113-
///
114-
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
115-
/// to the approximate calculation function.
116-
///
117-
/// Also example with NVVM:
118-
/// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
119-
///
120-
/// will be transformed into
121-
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
122-
template <typename SourceOp>
123-
struct OpToFuncCallLowering
124-
: public OpToFuncCallLoweringBase<SourceOp,
125-
OpToFuncCallLowering<SourceOp>> {
126-
public:
127-
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
128-
StringRef f32Func, StringRef f64Func,
129-
StringRef f32ApproxFunc, StringRef f16Func)
130-
: OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
131-
lowering),
132-
f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
133-
f16Func(f16Func) {}
134-
135-
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
136-
Type type = operand.getType();
137-
if (!isa<Float16Type, BFloat16Type>(type))
138-
return operand;
139-
140-
// if there's a f16 function, no need to cast f16 values
141-
if (!f16Func.empty() && isa<Float16Type>(type))
142-
return operand;
143-
144-
return rewriter.create<LLVM::FPExtOp>(
145-
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
146-
}
147149

148150
StringRef getFunctionName(Type type, SourceOp op) const {
149-
arith::FastMathFlags flag = op.getFastmath();
151+
bool useApprox = false;
152+
if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
153+
arith::FastMathFlags flag = op.getFastmath();
154+
useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
155+
!f32ApproxFunc.empty();
156+
}
157+
150158
if (isa<Float16Type>(type))
151159
return f16Func;
152160
if (isa<Float32Type>(type)) {
153-
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
154-
!f32ApproxFunc.empty())
161+
if (useApprox)
155162
return f32ApproxFunc;
156-
else
157-
return f32Func;
163+
return f32Func;
158164
}
159165
if (isa<Float64Type>(type))
160166
return f64Func;
167+
168+
if (type.isInteger(32))
169+
return i32Func;
161170
return "";
162171
}
163172

164173
const std::string f32Func;
165174
const std::string f64Func;
166175
const std::string f32ApproxFunc;
167176
const std::string f16Func;
168-
};
169-
170-
/// Rewriting that replace SourceOp with a CallOp to `i32Func`
171-
/// The function declaration is added in case it was not added before.
172-
/// This assumes that all types integral.
173-
///
174-
/// Example with NVVM:
175-
/// %abs_i32 = math.iabs %arg_i32 : i32
176-
///
177-
/// will be transformed into
178-
/// llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
179-
///
180-
template <typename SourceOp>
181-
struct IntOpToFuncCallLowering
182-
: public OpToFuncCallLoweringBase<SourceOp,
183-
IntOpToFuncCallLowering<SourceOp>> {
184-
public:
185-
explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
186-
StringRef i32Func)
187-
: OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
188-
lowering),
189-
i32Func(i32Func) {}
190-
191-
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
192-
return operand;
193-
}
194-
195-
StringRef getFunctionName(Type type, SourceOp op) const {
196-
IntegerType itype = dyn_cast<IntegerType>(type);
197-
if (!itype || itype.getWidth() != 32)
198-
return "";
199-
return i32Func;
200-
}
201-
202177
const std::string i32Func;
203178
};
204179

205-
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
206-
/// depending on the type of the result. This assumes that the first argument is
207-
/// a floating type and the second argument is an integer type.
208-
///
209-
/// Example with NVVM:
210-
/// %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
211-
///
212-
/// will be transformed into
213-
/// llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32
214-
///
215-
template <typename SourceOp>
216-
struct FloatIntOpToFuncCallLowering
217-
: public OpToFuncCallLoweringBase<SourceOp,
218-
FloatIntOpToFuncCallLowering<SourceOp>> {
219-
public:
220-
explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
221-
StringRef f32Func, StringRef f64Func)
222-
: OpToFuncCallLoweringBase<SourceOp,
223-
FloatIntOpToFuncCallLowering<SourceOp>>(
224-
lowering),
225-
f32Func(f32Func), f64Func(f64Func) {}
226-
227-
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
228-
return operand;
229-
}
230-
231-
StringRef getFunctionName(Type type, SourceOp op) const {
232-
if (isa<Float32Type>(type)) {
233-
return f32Func;
234-
}
235-
if (isa<Float64Type>(type))
236-
return f64Func;
237-
return "";
238-
}
239-
240-
const std::string f32Func;
241-
const std::string f64Func;
242-
};
243-
244180
} // namespace mlir
245181

246182
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,15 +451,15 @@ static void populateIntOpPatterns(const LLVMTypeConverter &converter,
451451
RewritePatternSet &patterns,
452452
StringRef i32Func) {
453453
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
454-
patterns.add<IntOpToFuncCallLowering<OpTy>>(converter, i32Func);
454+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func);
455455
}
456456

457457
template <typename OpTy>
458458
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
459459
RewritePatternSet &patterns,
460460
StringRef f32Func, StringRef f64Func) {
461461
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
462-
patterns.add<FloatIntOpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
462+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "");
463463
}
464464

465465
void mlir::populateGpuSubgroupReduceOpLoweringPattern(

0 commit comments

Comments
 (0)