Skip to content

Commit 68409b9

Browse files
committed
foo
Signed-off-by: Benoit Jacob <[email protected]>
1 parent d2b3912 commit 68409b9

File tree

3 files changed

+153
-67
lines changed

3 files changed

+153
-67
lines changed

mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,18 @@ class Pass;
1818
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
1919
#include "mlir/Conversion/Passes.h.inc"
2020

21+
enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };
22+
2123
/// Populate the given list with patterns that convert from Math to ROCDL calls.
22-
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
23-
RewritePatternSet &patterns);
24+
///
25+
/// Note that the default parameter value MathToROCDLConversionPatternKind::All
26+
/// is only for compatibility but is not recommended, because lumping together
27+
/// multiple conversion patters in the same pattern application can result in
28+
/// type conversion failures when one of the patterns failed.
29+
void populateMathToROCDLConversionPatterns(
30+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
31+
MathToROCDLConversionPatternKind patternKind =
32+
MathToROCDLConversionPatternKind::All);
2433
} // namespace mlir
2534

2635
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp

Lines changed: 125 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,25 @@ using namespace mlir;
3737

3838
template <typename OpTy>
3939
static void populateOpPatterns(const LLVMTypeConverter &converter,
40-
RewritePatternSet &patterns, StringRef f32Func,
41-
StringRef f64Func, StringRef f16Func,
40+
RewritePatternSet &patterns,
41+
MathToROCDLConversionPatternKind patternKind,
42+
StringRef f32Func, StringRef f64Func,
43+
StringRef f16Func,
4244
StringRef f32ApproxFunc = "") {
43-
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
44-
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45-
f32ApproxFunc, f16Func);
45+
if (patternKind == MathToROCDLConversionPatternKind::All ||
46+
patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
47+
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
48+
}
49+
if (patternKind == MathToROCDLConversionPatternKind::All ||
50+
patternKind == MathToROCDLConversionPatternKind::Lowerings) {
51+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
52+
f32ApproxFunc, f16Func);
53+
}
4654
}
4755

4856
void mlir::populateMathToROCDLConversionPatterns(
49-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
57+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
58+
MathToROCDLConversionPatternKind patternKind) {
5059
// Handled by mathToLLVM: math::AbsIOp
5160
// Handled by mathToLLVM: math::AbsFOp
5261
// Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
6170
// Handled by mathToLLVM: math::RoundOp
6271
// Handled by mathToLLVM: math::SqrtOp
6372
// Handled by mathToLLVM: math::TruncOp
64-
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
65-
"__ocml_acos_f64", "__ocml_acos_f16");
66-
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
67-
"__ocml_acosh_f64", "__ocml_acosh_f16");
68-
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
69-
"__ocml_asin_f64", "__ocml_asin_f16");
70-
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
71-
"__ocml_asinh_f64", "__ocml_asinh_f16");
72-
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
73-
"__ocml_atan_f64", "__ocml_atan_f16");
74-
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
75-
"__ocml_atanh_f64", "__ocml_atanh_f16");
76-
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
77-
"__ocml_atan2_f64", "__ocml_atan2_f16");
78-
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
79-
"__ocml_cbrt_f64", "__ocml_cbrt_f16");
80-
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
81-
"__ocml_ceil_f64", "__ocml_ceil_f16");
82-
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
83-
"__ocml_cos_f64", "__ocml_cos_f16");
84-
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
85-
"__ocml_cosh_f64", "__ocml_cosh_f16");
86-
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
87-
"__ocml_sinh_f64", "__ocml_sinh_f16");
88-
populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
89-
"__ocml_exp_f16");
90-
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
91-
"__ocml_exp2_f64", "__ocml_exp2_f16");
92-
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
93-
"__ocml_expm1_f64", "__ocml_expm1_f16");
94-
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
95-
"__ocml_floor_f64", "__ocml_floor_f16");
96-
populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
97-
"__ocml_log_f16");
98-
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
99-
"__ocml_log10_f64", "__ocml_log10_f16");
100-
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
101-
"__ocml_log1p_f64", "__ocml_log1p_f16");
102-
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
103-
"__ocml_log2_f64", "__ocml_log2_f16");
104-
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
105-
"__ocml_pow_f64", "__ocml_pow_f16");
106-
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
107-
"__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
108-
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
109-
"__ocml_sin_f64", "__ocml_sin_f16");
110-
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
111-
"__ocml_tanh_f64", "__ocml_tanh_f16");
112-
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
113-
"__ocml_tan_f64", "__ocml_tan_f16");
114-
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
115-
"__ocml_erf_f64", "__ocml_erf_f16");
116-
populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
117-
"__ocml_pown_f64", "__ocml_pown_f16");
73+
populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
74+
"__ocml_acos_f32", "__ocml_acos_f64",
75+
"__ocml_acos_f16");
76+
populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
77+
"__ocml_acosh_f32", "__ocml_acosh_f64",
78+
"__ocml_acosh_f16");
79+
populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
80+
"__ocml_asin_f32", "__ocml_asin_f64",
81+
"__ocml_asin_f16");
82+
populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
83+
"__ocml_asinh_f32", "__ocml_asinh_f64",
84+
"__ocml_asinh_f16");
85+
populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
86+
"__ocml_atan_f32", "__ocml_atan_f64",
87+
"__ocml_atan_f16");
88+
populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
89+
"__ocml_atanh_f32", "__ocml_atanh_f64",
90+
"__ocml_atanh_f16");
91+
populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
92+
"__ocml_atan2_f32", "__ocml_atan2_f64",
93+
"__ocml_atan2_f16");
94+
populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
95+
"__ocml_cbrt_f32", "__ocml_cbrt_f64",
96+
"__ocml_cbrt_f16");
97+
populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
98+
"__ocml_ceil_f32", "__ocml_ceil_f64",
99+
"__ocml_ceil_f16");
100+
populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
101+
"__ocml_cos_f32", "__ocml_cos_f64",
102+
"__ocml_cos_f16");
103+
populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
104+
"__ocml_cosh_f32", "__ocml_cosh_f64",
105+
"__ocml_cosh_f16");
106+
populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
107+
"__ocml_sinh_f32", "__ocml_sinh_f64",
108+
"__ocml_sinh_f16");
109+
populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
110+
"__ocml_exp_f64", "__ocml_exp_f16");
111+
populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
112+
"__ocml_exp2_f32", "__ocml_exp2_f64",
113+
"__ocml_exp2_f16");
114+
populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
115+
"__ocml_expm1_f32", "__ocml_expm1_f64",
116+
"__ocml_expm1_f16");
117+
populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
118+
"__ocml_floor_f32", "__ocml_floor_f64",
119+
"__ocml_floor_f16");
120+
populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
121+
"__ocml_log_f64", "__ocml_log_f16");
122+
populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
123+
"__ocml_log10_f32", "__ocml_log10_f64",
124+
"__ocml_log10_f16");
125+
populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
126+
"__ocml_log1p_f32", "__ocml_log1p_f64",
127+
"__ocml_log1p_f16");
128+
populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
129+
"__ocml_log2_f32", "__ocml_log2_f64",
130+
"__ocml_log2_f16");
131+
populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
132+
"__ocml_pow_f32", "__ocml_pow_f64",
133+
"__ocml_pow_f16");
134+
populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
135+
"__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
136+
"__ocml_rsqrt_f16");
137+
populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
138+
"__ocml_sin_f32", "__ocml_sin_f64",
139+
"__ocml_sin_f16");
140+
populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
141+
"__ocml_tanh_f32", "__ocml_tanh_f64",
142+
"__ocml_tanh_f16");
143+
populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
144+
"__ocml_tan_f32", "__ocml_tan_f64",
145+
"__ocml_tan_f16");
146+
populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
147+
"__ocml_erf_f32", "__ocml_erf_f64",
148+
"__ocml_erf_f16");
149+
populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
150+
"__ocml_pown_f32", "__ocml_pown_f64",
151+
"__ocml_pown_f16");
118152
// Single arith pattern that needs a ROCDL call, probably not
119153
// worth creating a separate pass for it.
120-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
121-
"__ocml_fmod_f64", "__ocml_fmod_f16");
154+
populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
155+
"__ocml_fmod_f32", "__ocml_fmod_f64",
156+
"__ocml_fmod_f16");
122157
}
123158

124159
namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
133168
auto m = getOperation();
134169
MLIRContext *ctx = m.getContext();
135170

136-
RewritePatternSet patterns(&getContext());
137171
LowerToLLVMOptions options(ctx, DataLayout(m));
138172
LLVMTypeConverter converter(ctx, options);
139-
populateMathToROCDLConversionPatterns(converter, patterns);
173+
174+
// The two pattern applications below will use distinct ConversionTarget's,
175+
// but this is the common denominator.
140176
ConversionTarget target(getContext());
141177
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
142178
vector::VectorDialect, LLVM::LLVMDialect>();
179+
180+
// Perform the scalarizations. This is done in a separate pattern application
181+
// to ensure that scalarizations are done regardless of lowerings. It is
182+
// normal for some lowerings may fail to apply, when we purposely do not lower
183+
// a math op to a function call.
184+
RewritePatternSet scalarizationPatterns(&getContext());
185+
ConversionTarget scalarizationTarget(target);
186+
// Math ops are legal if their operands are not vectors.
187+
scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
188+
[&](Operation *op) {
189+
return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
190+
});
191+
populateMathToROCDLConversionPatterns(
192+
converter, scalarizationPatterns,
193+
MathToROCDLConversionPatternKind::Scalarizations);
194+
if (failed(applyPartialConversion(m, scalarizationTarget,
195+
std::move(scalarizationPatterns))))
196+
signalPassFailure();
197+
198+
// Perform the lowerings. The ops that must lower to function calls become
199+
// illegal.
143200
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
144201
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
145202
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
146203
LLVM::SqrtOp>();
147-
if (failed(applyPartialConversion(m, target, std::move(patterns))))
204+
RewritePatternSet loweringPatterns(&getContext());
205+
populateMathToROCDLConversionPatterns(
206+
converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
207+
if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
148208
signalPassFailure();
149209
}

mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,20 @@ module @test_module {
578578
func.return %result : vector<2x2xf16>
579579
}
580580
}
581+
582+
// -----
583+
584+
module @test_module {
585+
// This test case covers the case of math ops that do not have a lowering to
586+
// a function call. When lowerings to call were lumped together with
587+
// scalarization in the same pattern application, they were preventing
588+
// scalarization.
589+
// CHECK-LABEL: func @math_log_f32_vector_0d
590+
func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
591+
// CHECK: llvm.extractelement {{.*}} : vector<1xf32>
592+
// CHECK: math.log {{.*}} : f32
593+
// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
594+
%result = math.log %arg : vector<f32>
595+
func.return %result : vector<f32>
596+
}
597+
}

0 commit comments

Comments
 (0)