@@ -38,17 +38,17 @@ using namespace mlir;
3838template <typename OpTy>
3939static void populateOpPatterns (LLVMTypeConverter &converter,
4040 RewritePatternSet &patterns, StringRef f32Func,
41- StringRef f64Func,
41+ StringRef f64Func, StringRef f16Func,
4242 StringRef f32ApproxFunc = " " ) {
4343 patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
4444 patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
45- f32ApproxFunc);
45+ f32ApproxFunc, f16Func );
4646}
4747
4848void mlir::populateMathToROCDLConversionPatterns (LLVMTypeConverter &converter,
4949 RewritePatternSet &patterns) {
5050 // Handled by mathToLLVM: math::AbsIOp
51- // Handled by mathToLLVM: math::AbsFIOp
51+ // Handled by mathToLLVM: math::AbsFOp
5252 // Handled by mathToLLVM: math::CopySignOp
5353 // Handled by mathToLLVM: math::CountLeadingZerosOp
5454 // Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
6363 // Handled by mathToLLVM: math::SqrtOp
6464 // Handled by mathToLLVM: math::TruncOp
6565 populateOpPatterns<math::AcosOp>(converter, patterns, " __ocml_acos_f32" ,
66- " __ocml_acos_f64" );
66+ " __ocml_acos_f64" , " __ocml_acos_f16 " );
6767 populateOpPatterns<math::AcoshOp>(converter, patterns, " __ocml_acosh_f32" ,
68- " __ocml_acosh_f64" );
68+ " __ocml_acosh_f64" , " __ocml_acosh_f16 " );
6969 populateOpPatterns<math::AsinOp>(converter, patterns, " __ocml_asin_f32" ,
70- " __ocml_asin_f64" );
70+ " __ocml_asin_f64" , " __ocml_asin_f16 " );
7171 populateOpPatterns<math::AsinhOp>(converter, patterns, " __ocml_asinh_f32" ,
72- " __ocml_asinh_f64" );
72+ " __ocml_asinh_f64" , " __ocml_asinh_f16 " );
7373 populateOpPatterns<math::AtanOp>(converter, patterns, " __ocml_atan_f32" ,
74- " __ocml_atan_f64" );
74+ " __ocml_atan_f64" , " __ocml_atan_f16 " );
7575 populateOpPatterns<math::AtanhOp>(converter, patterns, " __ocml_atanh_f32" ,
76- " __ocml_atanh_f64" );
76+ " __ocml_atanh_f64" , " __ocml_atanh_f16 " );
7777 populateOpPatterns<math::Atan2Op>(converter, patterns, " __ocml_atan2_f32" ,
78- " __ocml_atan2_f64" );
78+ " __ocml_atan2_f64" , " __ocml_atan2_f16 " );
7979 populateOpPatterns<math::CbrtOp>(converter, patterns, " __ocml_cbrt_f32" ,
80- " __ocml_cbrt_f64" );
80+ " __ocml_cbrt_f64" , " __ocml_cbrt_f16 " );
8181 populateOpPatterns<math::CeilOp>(converter, patterns, " __ocml_ceil_f32" ,
82- " __ocml_ceil_f64" );
82+ " __ocml_ceil_f64" , " __ocml_ceil_f16 " );
8383 populateOpPatterns<math::CosOp>(converter, patterns, " __ocml_cos_f32" ,
84- " __ocml_cos_f64" );
84+ " __ocml_cos_f64" , " __ocml_cos_f16 " );
8585 populateOpPatterns<math::CoshOp>(converter, patterns, " __ocml_cosh_f32" ,
86- " __ocml_cosh_f64" );
86+ " __ocml_cosh_f64" , " __ocml_cosh_f16 " );
8787 populateOpPatterns<math::SinhOp>(converter, patterns, " __ocml_sinh_f32" ,
88- " __ocml_sinh_f64" );
89- populateOpPatterns<math::ExpOp>(converter, patterns, " " , " __ocml_exp_f64" );
88+ " __ocml_sinh_f64" , " __ocml_sinh_f16" );
89+ populateOpPatterns<math::ExpOp>(converter, patterns, " " , " __ocml_exp_f64" ,
90+ " __ocml_exp_f16" );
9091 populateOpPatterns<math::Exp2Op>(converter, patterns, " __ocml_exp2_f32" ,
91- " __ocml_exp2_f64" );
92+ " __ocml_exp2_f64" , " __ocml_exp2_f16 " );
9293 populateOpPatterns<math::ExpM1Op>(converter, patterns, " __ocml_expm1_f32" ,
93- " __ocml_expm1_f64" );
94+ " __ocml_expm1_f64" , " __ocml_expm1_f16 " );
9495 populateOpPatterns<math::FloorOp>(converter, patterns, " __ocml_floor_f32" ,
95- " __ocml_floor_f64" );
96- populateOpPatterns<math::LogOp>(converter, patterns, " " , " __ocml_log_f64" );
96+ " __ocml_floor_f64" , " __ocml_floor_f16" );
97+ populateOpPatterns<math::LogOp>(converter, patterns, " " , " __ocml_log_f64" ,
98+ " __ocml_log_f16" );
9799 populateOpPatterns<math::Log10Op>(converter, patterns, " __ocml_log10_f32" ,
98- " __ocml_log10_f64" );
100+ " __ocml_log10_f64" , " __ocml_log10_f16 " );
99101 populateOpPatterns<math::Log1pOp>(converter, patterns, " __ocml_log1p_f32" ,
100- " __ocml_log1p_f64" );
102+ " __ocml_log1p_f64" , " __ocml_log1p_f16 " );
101103 populateOpPatterns<math::Log2Op>(converter, patterns, " __ocml_log2_f32" ,
102- " __ocml_log2_f64" );
104+ " __ocml_log2_f64" , " __ocml_log2_f16 " );
103105 populateOpPatterns<math::PowFOp>(converter, patterns, " __ocml_pow_f32" ,
104- " __ocml_pow_f64" );
106+ " __ocml_pow_f64" , " __ocml_pow_f16 " );
105107 populateOpPatterns<math::RsqrtOp>(converter, patterns, " __ocml_rsqrt_f32" ,
106- " __ocml_rsqrt_f64" );
108+ " __ocml_rsqrt_f64" , " __ocml_rsqrt_f16 " );
107109 populateOpPatterns<math::SinOp>(converter, patterns, " __ocml_sin_f32" ,
108- " __ocml_sin_f64" );
110+ " __ocml_sin_f64" , " __ocml_sin_f16 " );
109111 populateOpPatterns<math::TanhOp>(converter, patterns, " __ocml_tanh_f32" ,
110- " __ocml_tanh_f64" );
112+ " __ocml_tanh_f64" , " __ocml_tanh_f16 " );
111113 populateOpPatterns<math::TanOp>(converter, patterns, " __ocml_tan_f32" ,
112- " __ocml_tan_f64" );
114+ " __ocml_tan_f64" , " __ocml_tan_f16 " );
113115 populateOpPatterns<math::ErfOp>(converter, patterns, " __ocml_erf_f32" ,
114- " __ocml_erf_f64" );
116+ " __ocml_erf_f64" , " __ocml_erf_f16 " );
115117 // Single arith pattern that needs a ROCDL call, probably not
116118 // worth creating a separate pass for it.
117119 populateOpPatterns<arith::RemFOp>(converter, patterns, " __ocml_fmod_f32" ,
118- " __ocml_fmod_f64" );
120+ " __ocml_fmod_f64" , " __ocml_fmod_f16 " );
119121}
120122
121123namespace {
0 commit comments