@@ -37,16 +37,25 @@ using namespace mlir;
3737
3838template <typename OpTy>
3939static void populateOpPatterns (const LLVMTypeConverter &converter,
40- RewritePatternSet &patterns, StringRef f32Func,
41- StringRef f64Func, StringRef f16Func,
40+ RewritePatternSet &patterns,
41+ MathToROCDLConversionPatternKind patternKind,
42+ StringRef f32Func, StringRef f64Func,
43+ StringRef f16Func,
4244 StringRef f32ApproxFunc = " " ) {
43- patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
44- patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45- f32ApproxFunc, f16Func);
45+ if (patternKind == MathToROCDLConversionPatternKind::All ||
46+ patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
47+ patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
48+ }
49+ if (patternKind == MathToROCDLConversionPatternKind::All ||
50+ patternKind == MathToROCDLConversionPatternKind::Lowerings) {
51+ patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
52+ f32ApproxFunc, f16Func);
53+ }
4654}
4755
4856void mlir::populateMathToROCDLConversionPatterns (
49- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
57+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
58+ MathToROCDLConversionPatternKind patternKind) {
5059 // Handled by mathToLLVM: math::AbsIOp
5160 // Handled by mathToLLVM: math::AbsFOp
5261 // Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
6170 // Handled by mathToLLVM: math::RoundOp
6271 // Handled by mathToLLVM: math::SqrtOp
6372 // Handled by mathToLLVM: math::TruncOp
64- populateOpPatterns<math::AcosOp>(converter, patterns, " __ocml_acos_f32" ,
65- " __ocml_acos_f64" , " __ocml_acos_f16" );
66- populateOpPatterns<math::AcoshOp>(converter, patterns, " __ocml_acosh_f32" ,
67- " __ocml_acosh_f64" , " __ocml_acosh_f16" );
68- populateOpPatterns<math::AsinOp>(converter, patterns, " __ocml_asin_f32" ,
69- " __ocml_asin_f64" , " __ocml_asin_f16" );
70- populateOpPatterns<math::AsinhOp>(converter, patterns, " __ocml_asinh_f32" ,
71- " __ocml_asinh_f64" , " __ocml_asinh_f16" );
72- populateOpPatterns<math::AtanOp>(converter, patterns, " __ocml_atan_f32" ,
73- " __ocml_atan_f64" , " __ocml_atan_f16" );
74- populateOpPatterns<math::AtanhOp>(converter, patterns, " __ocml_atanh_f32" ,
75- " __ocml_atanh_f64" , " __ocml_atanh_f16" );
76- populateOpPatterns<math::Atan2Op>(converter, patterns, " __ocml_atan2_f32" ,
77- " __ocml_atan2_f64" , " __ocml_atan2_f16" );
78- populateOpPatterns<math::CbrtOp>(converter, patterns, " __ocml_cbrt_f32" ,
79- " __ocml_cbrt_f64" , " __ocml_cbrt_f16" );
80- populateOpPatterns<math::CeilOp>(converter, patterns, " __ocml_ceil_f32" ,
81- " __ocml_ceil_f64" , " __ocml_ceil_f16" );
82- populateOpPatterns<math::CosOp>(converter, patterns, " __ocml_cos_f32" ,
83- " __ocml_cos_f64" , " __ocml_cos_f16" );
84- populateOpPatterns<math::CoshOp>(converter, patterns, " __ocml_cosh_f32" ,
85- " __ocml_cosh_f64" , " __ocml_cosh_f16" );
86- populateOpPatterns<math::SinhOp>(converter, patterns, " __ocml_sinh_f32" ,
87- " __ocml_sinh_f64" , " __ocml_sinh_f16" );
88- populateOpPatterns<math::ExpOp>(converter, patterns, " " , " __ocml_exp_f64" ,
89- " __ocml_exp_f16" );
90- populateOpPatterns<math::Exp2Op>(converter, patterns, " __ocml_exp2_f32" ,
91- " __ocml_exp2_f64" , " __ocml_exp2_f16" );
92- populateOpPatterns<math::ExpM1Op>(converter, patterns, " __ocml_expm1_f32" ,
93- " __ocml_expm1_f64" , " __ocml_expm1_f16" );
94- populateOpPatterns<math::FloorOp>(converter, patterns, " __ocml_floor_f32" ,
95- " __ocml_floor_f64" , " __ocml_floor_f16" );
96- populateOpPatterns<math::LogOp>(converter, patterns, " " , " __ocml_log_f64" ,
97- " __ocml_log_f16" );
98- populateOpPatterns<math::Log10Op>(converter, patterns, " __ocml_log10_f32" ,
99- " __ocml_log10_f64" , " __ocml_log10_f16" );
100- populateOpPatterns<math::Log1pOp>(converter, patterns, " __ocml_log1p_f32" ,
101- " __ocml_log1p_f64" , " __ocml_log1p_f16" );
102- populateOpPatterns<math::Log2Op>(converter, patterns, " __ocml_log2_f32" ,
103- " __ocml_log2_f64" , " __ocml_log2_f16" );
104- populateOpPatterns<math::PowFOp>(converter, patterns, " __ocml_pow_f32" ,
105- " __ocml_pow_f64" , " __ocml_pow_f16" );
106- populateOpPatterns<math::RsqrtOp>(converter, patterns, " __ocml_rsqrt_f32" ,
107- " __ocml_rsqrt_f64" , " __ocml_rsqrt_f16" );
108- populateOpPatterns<math::SinOp>(converter, patterns, " __ocml_sin_f32" ,
109- " __ocml_sin_f64" , " __ocml_sin_f16" );
110- populateOpPatterns<math::TanhOp>(converter, patterns, " __ocml_tanh_f32" ,
111- " __ocml_tanh_f64" , " __ocml_tanh_f16" );
112- populateOpPatterns<math::TanOp>(converter, patterns, " __ocml_tan_f32" ,
113- " __ocml_tan_f64" , " __ocml_tan_f16" );
114- populateOpPatterns<math::ErfOp>(converter, patterns, " __ocml_erf_f32" ,
115- " __ocml_erf_f64" , " __ocml_erf_f16" );
116- populateOpPatterns<math::FPowIOp>(converter, patterns, " __ocml_pown_f32" ,
117- " __ocml_pown_f64" , " __ocml_pown_f16" );
73+ populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
74+ " __ocml_acos_f32" , " __ocml_acos_f64" ,
75+ " __ocml_acos_f16" );
76+ populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
77+ " __ocml_acosh_f32" , " __ocml_acosh_f64" ,
78+ " __ocml_acosh_f16" );
79+ populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
80+ " __ocml_asin_f32" , " __ocml_asin_f64" ,
81+ " __ocml_asin_f16" );
82+ populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
83+ " __ocml_asinh_f32" , " __ocml_asinh_f64" ,
84+ " __ocml_asinh_f16" );
85+ populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
86+ " __ocml_atan_f32" , " __ocml_atan_f64" ,
87+ " __ocml_atan_f16" );
88+ populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
89+ " __ocml_atanh_f32" , " __ocml_atanh_f64" ,
90+ " __ocml_atanh_f16" );
91+ populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
92+ " __ocml_atan2_f32" , " __ocml_atan2_f64" ,
93+ " __ocml_atan2_f16" );
94+ populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
95+ " __ocml_cbrt_f32" , " __ocml_cbrt_f64" ,
96+ " __ocml_cbrt_f16" );
97+ populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
98+ " __ocml_ceil_f32" , " __ocml_ceil_f64" ,
99+ " __ocml_ceil_f16" );
100+ populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
101+ " __ocml_cos_f32" , " __ocml_cos_f64" ,
102+ " __ocml_cos_f16" );
103+ populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
104+ " __ocml_cosh_f32" , " __ocml_cosh_f64" ,
105+ " __ocml_cosh_f16" );
106+ populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
107+ " __ocml_sinh_f32" , " __ocml_sinh_f64" ,
108+ " __ocml_sinh_f16" );
109+ populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, " " ,
110+ " __ocml_exp_f64" , " __ocml_exp_f16" );
111+ populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
112+ " __ocml_exp2_f32" , " __ocml_exp2_f64" ,
113+ " __ocml_exp2_f16" );
114+ populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
115+ " __ocml_expm1_f32" , " __ocml_expm1_f64" ,
116+ " __ocml_expm1_f16" );
117+ populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
118+ " __ocml_floor_f32" , " __ocml_floor_f64" ,
119+ " __ocml_floor_f16" );
120+ populateOpPatterns<math::LogOp>(converter, patterns, patternKind, " " ,
121+ " __ocml_log_f64" , " __ocml_log_f16" );
122+ populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
123+ " __ocml_log10_f32" , " __ocml_log10_f64" ,
124+ " __ocml_log10_f16" );
125+ populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
126+ " __ocml_log1p_f32" , " __ocml_log1p_f64" ,
127+ " __ocml_log1p_f16" );
128+ populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
129+ " __ocml_log2_f32" , " __ocml_log2_f64" ,
130+ " __ocml_log2_f16" );
131+ populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
132+ " __ocml_pow_f32" , " __ocml_pow_f64" ,
133+ " __ocml_pow_f16" );
134+ populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
135+ " __ocml_rsqrt_f32" , " __ocml_rsqrt_f64" ,
136+ " __ocml_rsqrt_f16" );
137+ populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
138+ " __ocml_sin_f32" , " __ocml_sin_f64" ,
139+ " __ocml_sin_f16" );
140+ populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
141+ " __ocml_tanh_f32" , " __ocml_tanh_f64" ,
142+ " __ocml_tanh_f16" );
143+ populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
144+ " __ocml_tan_f32" , " __ocml_tan_f64" ,
145+ " __ocml_tan_f16" );
146+ populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
147+ " __ocml_erf_f32" , " __ocml_erf_f64" ,
148+ " __ocml_erf_f16" );
149+ populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
150+ " __ocml_pown_f32" , " __ocml_pown_f64" ,
151+ " __ocml_pown_f16" );
118152 // Single arith pattern that needs a ROCDL call, probably not
119153 // worth creating a separate pass for it.
120- populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
121- " __ocml_fmod_f64" , " __ocml_fmod_f16" );
154+ populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
155+ " __ocml_fmod_f32" , " __ocml_fmod_f64" ,
156+ " __ocml_fmod_f16" );
122157}
123158
124159namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
133168 auto m = getOperation ();
134169 MLIRContext *ctx = m.getContext ();
135170
136- RewritePatternSet patterns (&getContext ());
137171 LowerToLLVMOptions options (ctx, DataLayout (m));
138172 LLVMTypeConverter converter (ctx, options);
139- populateMathToROCDLConversionPatterns (converter, patterns);
173+
174+ // The two pattern applications below will use distinct ConversionTarget's,
175+ // but this is the common denominator.
140176 ConversionTarget target (getContext ());
141177 target.addLegalDialect <BuiltinDialect, func::FuncDialect,
142178 vector::VectorDialect, LLVM::LLVMDialect>();
179+
180+ // Perform the scalarizations. This is done in a separate pattern application
181+ // to ensure that scalarizations are done regardless of lowerings. It is
182+ // normal for some lowerings may fail to apply, when we purposely do not lower
183+ // a math op to a function call.
184+ RewritePatternSet scalarizationPatterns (&getContext ());
185+ ConversionTarget scalarizationTarget (target);
186+ // Math ops are legal if their operands are not vectors.
187+ scalarizationTarget.addDynamicallyLegalDialect <math::MathDialect>(
188+ [&](Operation *op) {
189+ return llvm::none_of (op->getOperandTypes (), llvm::IsaPred<VectorType>);
190+ });
191+ populateMathToROCDLConversionPatterns (
192+ converter, scalarizationPatterns,
193+ MathToROCDLConversionPatternKind::Scalarizations);
194+ if (failed (applyPartialConversion (m, scalarizationTarget,
195+ std::move (scalarizationPatterns))))
196+ signalPassFailure ();
197+
198+ // Perform the lowerings. The ops that must lower to function calls become
199+ // illegal.
143200 target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
144201 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
145202 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
146203 LLVM::SqrtOp>();
147- if (failed (applyPartialConversion (m, target, std::move (patterns))))
204+ RewritePatternSet loweringPatterns (&getContext ());
205+ populateMathToROCDLConversionPatterns (
206+ converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
207+ if (failed (applyPartialConversion (m, target, std::move (loweringPatterns))))
148208 signalPassFailure ();
149209}
0 commit comments