From 68409b900f56a373168952bb9dcfb2ea0ad7cb00 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 21 Feb 2025 10:10:51 -0600 Subject: [PATCH] foo Signed-off-by: Benoit Jacob --- .../mlir/Conversion/MathToROCDL/MathToROCDL.h | 13 +- .../Conversion/MathToROCDL/MathToROCDL.cpp | 190 ++++++++++++------ .../Conversion/MathToROCDL/math-to-rocdl.mlir | 17 ++ 3 files changed, 153 insertions(+), 67 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e7966ccc..7d5c487a9dbff 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -18,9 +18,18 @@ class Pass; #define GEN_PASS_DECL_CONVERTMATHTOROCDL #include "mlir/Conversion/Passes.h.inc" +enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings }; + /// Populate the given list with patterns that convert from Math to ROCDL calls. -void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); +/// +/// Note that the default parameter value MathToROCDLConversionPatternKind::All +/// is only for compatibility but is not recommended, because lumping together +/// multiple conversion patters in the same pattern application can result in +/// type conversion failures when one of the patterns failed. +void populateMathToROCDLConversionPatterns( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + MathToROCDLConversionPatternKind patternKind = + MathToROCDLConversionPatternKind::All); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 838eef30a938f..bd8578d70c260 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -37,16 +37,25 @@ using namespace mlir; template static void populateOpPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, StringRef f32Func, - StringRef f64Func, StringRef f16Func, + RewritePatternSet &patterns, + MathToROCDLConversionPatternKind patternKind, + StringRef f32Func, StringRef f64Func, + StringRef f16Func, StringRef f32ApproxFunc = "") { - patterns.add>(converter); - patterns.add>(converter, f32Func, f64Func, - f32ApproxFunc, f16Func); + if (patternKind == MathToROCDLConversionPatternKind::All || + patternKind == MathToROCDLConversionPatternKind::Scalarizations) { + patterns.add>(converter); + } + if (patternKind == MathToROCDLConversionPatternKind::All || + patternKind == MathToROCDLConversionPatternKind::Lowerings) { + patterns.add>(converter, f32Func, f64Func, + f32ApproxFunc, f16Func); + } } void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + MathToROCDLConversionPatternKind patternKind) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns( // Handled by mathToLLVM: math::RoundOp // Handled by mathToLLVM: math::SqrtOp // Handled by mathToLLVM: math::TruncOp - populateOpPatterns(converter, patterns, "__ocml_acos_f32", - "__ocml_acos_f64", "__ocml_acos_f16"); - populateOpPatterns(converter, patterns, "__ocml_acosh_f32", - "__ocml_acosh_f64", "__ocml_acosh_f16"); - populateOpPatterns(converter, patterns, "__ocml_asin_f32", - "__ocml_asin_f64", "__ocml_asin_f16"); - populateOpPatterns(converter, patterns, "__ocml_asinh_f32", - "__ocml_asinh_f64", "__ocml_asinh_f16"); - populateOpPatterns(converter, patterns, "__ocml_atan_f32", - "__ocml_atan_f64", "__ocml_atan_f16"); - populateOpPatterns(converter, patterns, "__ocml_atanh_f32", - "__ocml_atanh_f64", "__ocml_atanh_f16"); - populateOpPatterns(converter, patterns, "__ocml_atan2_f32", - "__ocml_atan2_f64", "__ocml_atan2_f16"); - populateOpPatterns(converter, patterns, "__ocml_cbrt_f32", - "__ocml_cbrt_f64", "__ocml_cbrt_f16"); - populateOpPatterns(converter, patterns, "__ocml_ceil_f32", - "__ocml_ceil_f64", "__ocml_ceil_f16"); - populateOpPatterns(converter, patterns, "__ocml_cos_f32", - "__ocml_cos_f64", "__ocml_cos_f16"); - populateOpPatterns(converter, patterns, "__ocml_cosh_f32", - "__ocml_cosh_f64", "__ocml_cosh_f16"); - populateOpPatterns(converter, patterns, "__ocml_sinh_f32", - "__ocml_sinh_f64", "__ocml_sinh_f16"); - populateOpPatterns(converter, patterns, "", "__ocml_exp_f64", - "__ocml_exp_f16"); - populateOpPatterns(converter, patterns, "__ocml_exp2_f32", - "__ocml_exp2_f64", "__ocml_exp2_f16"); - populateOpPatterns(converter, patterns, "__ocml_expm1_f32", - "__ocml_expm1_f64", "__ocml_expm1_f16"); - populateOpPatterns(converter, patterns, "__ocml_floor_f32", - "__ocml_floor_f64", "__ocml_floor_f16"); - populateOpPatterns(converter, patterns, "", "__ocml_log_f64", - "__ocml_log_f16"); - populateOpPatterns(converter, patterns, "__ocml_log10_f32", - "__ocml_log10_f64", "__ocml_log10_f16"); - populateOpPatterns(converter, patterns, "__ocml_log1p_f32", - "__ocml_log1p_f64", "__ocml_log1p_f16"); - populateOpPatterns(converter, patterns, "__ocml_log2_f32", - "__ocml_log2_f64", "__ocml_log2_f16"); - populateOpPatterns(converter, patterns, "__ocml_pow_f32", - "__ocml_pow_f64", "__ocml_pow_f16"); - populateOpPatterns(converter, patterns, "__ocml_rsqrt_f32", - "__ocml_rsqrt_f64", "__ocml_rsqrt_f16"); - populateOpPatterns(converter, patterns, "__ocml_sin_f32", - "__ocml_sin_f64", "__ocml_sin_f16"); - populateOpPatterns(converter, patterns, "__ocml_tanh_f32", - "__ocml_tanh_f64", "__ocml_tanh_f16"); - populateOpPatterns(converter, patterns, "__ocml_tan_f32", - "__ocml_tan_f64", "__ocml_tan_f16"); - populateOpPatterns(converter, patterns, "__ocml_erf_f32", - "__ocml_erf_f64", "__ocml_erf_f16"); - populateOpPatterns(converter, patterns, "__ocml_pown_f32", - "__ocml_pown_f64", "__ocml_pown_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_acos_f32", "__ocml_acos_f64", + "__ocml_acos_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_acosh_f32", "__ocml_acosh_f64", + "__ocml_acosh_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_asin_f32", "__ocml_asin_f64", + "__ocml_asin_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_asinh_f32", "__ocml_asinh_f64", + "__ocml_asinh_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_atan_f32", "__ocml_atan_f64", + "__ocml_atan_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_atanh_f32", "__ocml_atanh_f64", + "__ocml_atanh_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_atan2_f32", "__ocml_atan2_f64", + "__ocml_atan2_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_cbrt_f32", "__ocml_cbrt_f64", + "__ocml_cbrt_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_ceil_f32", "__ocml_ceil_f64", + "__ocml_ceil_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_cos_f32", "__ocml_cos_f64", + "__ocml_cos_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_cosh_f32", "__ocml_cosh_f64", + "__ocml_cosh_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_sinh_f32", "__ocml_sinh_f64", + "__ocml_sinh_f16"); + populateOpPatterns(converter, patterns, patternKind, "", + "__ocml_exp_f64", "__ocml_exp_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_exp2_f32", "__ocml_exp2_f64", + "__ocml_exp2_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_expm1_f32", "__ocml_expm1_f64", + "__ocml_expm1_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_floor_f32", "__ocml_floor_f64", + "__ocml_floor_f16"); + populateOpPatterns(converter, patterns, patternKind, "", + "__ocml_log_f64", "__ocml_log_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_log10_f32", "__ocml_log10_f64", + "__ocml_log10_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_log1p_f32", "__ocml_log1p_f64", + "__ocml_log1p_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_log2_f32", "__ocml_log2_f64", + "__ocml_log2_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_pow_f32", "__ocml_pow_f64", + "__ocml_pow_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_rsqrt_f32", "__ocml_rsqrt_f64", + "__ocml_rsqrt_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_sin_f32", "__ocml_sin_f64", + "__ocml_sin_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_tanh_f32", "__ocml_tanh_f64", + "__ocml_tanh_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_tan_f32", "__ocml_tan_f64", + "__ocml_tan_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_erf_f32", "__ocml_erf_f64", + "__ocml_erf_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_pown_f32", "__ocml_pown_f64", + "__ocml_pown_f16"); // Single arith pattern that needs a ROCDL call, probably not // worth creating a separate pass for it. - populateOpPatterns(converter, patterns, "__ocml_fmod_f32", - "__ocml_fmod_f64", "__ocml_fmod_f16"); + populateOpPatterns(converter, patterns, patternKind, + "__ocml_fmod_f32", "__ocml_fmod_f64", + "__ocml_fmod_f16"); } namespace { @@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); MLIRContext *ctx = m.getContext(); - RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + + // The two pattern applications below will use distinct ConversionTarget's, + // but this is the common denominator. ConversionTarget target(getContext()); target.addLegalDialect(); + + // Perform the scalarizations. This is done in a separate pattern application + // to ensure that scalarizations are done regardless of lowerings. It is + // normal for some lowerings may fail to apply, when we purposely do not lower + // a math op to a function call. + RewritePatternSet scalarizationPatterns(&getContext()); + ConversionTarget scalarizationTarget(target); + // Math ops are legal if their operands are not vectors. + scalarizationTarget.addDynamicallyLegalDialect( + [&](Operation *op) { + return llvm::none_of(op->getOperandTypes(), llvm::IsaPred); + }); + populateMathToROCDLConversionPatterns( + converter, scalarizationPatterns, + MathToROCDLConversionPatternKind::Scalarizations); + if (failed(applyPartialConversion(m, scalarizationTarget, + std::move(scalarizationPatterns)))) + signalPassFailure(); + + // Perform the lowerings. The ops that must lower to function calls become + // illegal. target.addIllegalOp(); - if (failed(applyPartialConversion(m, target, std::move(patterns)))) + RewritePatternSet loweringPatterns(&getContext()); + populateMathToROCDLConversionPatterns( + converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings); + if (failed(applyPartialConversion(m, target, std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 313d7b086731e..44ee2fcbcb7f8 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -578,3 +578,20 @@ module @test_module { func.return %result : vector<2x2xf16> } } + +// ----- + +module @test_module { + // This test case covers the case of math ops that do not have a lowering to + // a function call. When lowerings to call were lumped together with + // scalarization in the same pattern application, they were preventing + // scalarization. + // CHECK-LABEL: func @math_log_f32_vector_0d + func.func @math_log_f32_vector_0d(%arg : vector) -> vector { + // CHECK: llvm.extractelement {{.*}} : vector<1xf32> + // CHECK: math.log {{.*}} : f32 + // CHECK: llvm.insertelement {{.*}} : vector<1xf32> + %result = math.log %arg : vector + func.return %result : vector + } +}