From f5cf0218677a5019f78bfe451fcd343b19beb4c8 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 13 Oct 2025 12:55:36 -0700 Subject: [PATCH 1/9] [MLIR][ROCDL] Added math.clampf -> rocdl.fmed3 conversion Signed-off-by: Keshav Vinayak Jha --- .../mlir/Conversion/MathToROCDL/MathToROCDL.h | 4 +- mlir/include/mlir/Conversion/Passes.td | 8 + .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2 +- .../Conversion/MathToROCDL/MathToROCDL.cpp | 54 +- .../Conversion/MathToROCDL/math-to-rocdl.mlir | 941 +++++++++++++----- 5 files changed, 745 insertions(+), 264 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e7966ccc..770f257d89bd5 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include @@ -20,7 +21,8 @@ class Pass; /// Populate the given list with patterns that convert from Math to ROCDL calls. void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + amdgpu::Chipset chipset); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 3c18ecc753d0f..c3fd397e258ae 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -755,6 +755,14 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> { "func::FuncDialect", "vector::VectorDialect", ]; + let options = [ + Option<"chipset", "chipset", "std::string", + + + /*default=*/"\"gfx000\"", + "Chipset that these operations will run on"> + ]; + } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index b215211e131d4..c03f3a5d3889c 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns( GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add(converter, chipset); - populateMathToROCDLConversionPatterns(converter, patterns); + populateMathToROCDLConversionPatterns(converter, patterns, chipset); } diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3ff4f6e..ceb3d22c6bd59 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -42,8 +43,39 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +struct ClampFOpConversion final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + ClampFOpConversion(const LLVMTypeConverter &converter, + amdgpu::Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // V_MED3_F16/F32 only exists in gfx9+ artchitectures + if (chipset.majorVersion < 9) { + return rewriter.notifyMatchFailure( + op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) + + "): V_MED_F16 / V_MED3_F32 not supported.")); + } + rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } + amdgpu::Chipset chipset; +}; + +static void addChipsetDependentPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + amdgpu::Chipset chipset) { + + patterns.add(converter, chipset); +} + void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + amdgpu::Chipset chipset) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -118,27 +150,31 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); + + addChipsetDependentPatterns(converter, patterns, chipset); } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase { - ConvertMathToROCDLPass() = default; +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); MLIRContext *ctx = m.getContext(); + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); target.addIllegalOp f16 // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 // CHECK-LABEL: func @arith_remf - func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = arith.remf %arg_f16, %arg_f16 : f16 - // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 - %result32 = arith.remf %arg_f32, %arg_f32 : f32 - // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = arith.remf %arg_f64, %arg_f64 : f64 - // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @arith_remf(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = arith.remf % arg_f16, + % + arg_f16 : f16 + // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : + // (f16, f16) -> f16 + % + result32 = arith.remf % arg_f32, + % + arg_f32 : f32 + // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : + // (f32, f32) -> f32 + % + result64 = arith.remf % arg_f64, + % + arg_f64 : f64 + // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : + // (f64, f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -23,14 +45,28 @@ module @test_module { // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64 // CHECK-LABEL: func @math_acos - func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.acos %arg_f16 : f16 - // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.acos %arg_f32 : f32 - // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.acos %arg_f64 : f64 - // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_acos(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.acos % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.acos % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.acos % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -41,14 +77,28 @@ module @test_module { // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64 // CHECK-LABEL: func @math_acosh - func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.acosh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.acosh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.acosh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_acosh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.acosh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.acosh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.acosh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -59,14 +109,28 @@ module @test_module { // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64 // CHECK-LABEL: func @math_asin - func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.asin %arg_f16 : f16 - // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.asin %arg_f32 : f32 - // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.asin %arg_f64 : f64 - // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_asin(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.asin % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.asin % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.asin % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -77,14 +141,28 @@ module @test_module { // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64 // CHECK-LABEL: func @math_asinh - func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.asinh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.asinh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.asinh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_asinh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.asinh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.asinh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.asinh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -95,14 +173,28 @@ module @test_module { // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 // CHECK-LABEL: func @math_atan - func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.atan %arg_f16 : f16 - // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.atan %arg_f32 : f32 - // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.atan %arg_f64 : f64 - // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_atan(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.atan % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.atan % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.atan % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -113,14 +205,28 @@ module @test_module { // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64 // CHECK-LABEL: func @math_atanh - func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.atanh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.atanh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.atanh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_atanh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.atanh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.atanh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.atanh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -131,14 +237,31 @@ module @test_module { // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_atan2 - func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.atan2 %arg_f16, %arg_f16 : f16 - // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 - %result32 = math.atan2 %arg_f32, %arg_f32 : f32 - // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = math.atan2 %arg_f64, %arg_f64 : f64 - // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_atan2(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.atan2 % arg_f16, + % + arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : + // (f16, f16) -> f16 + % + result32 = math.atan2 % arg_f32, + % + arg_f32 : f32 + // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : + // (f32, f32) -> f32 + % + result64 = math.atan2 % arg_f64, + % + arg_f64 : f64 + // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) + // : (f64, f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -149,14 +272,28 @@ module @test_module { // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 // CHECK-LABEL: func @math_cbrt - func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.cbrt %arg_f16 : f16 - // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.cbrt %arg_f32 : f32 - // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.cbrt %arg_f64 : f64 - // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_cbrt(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.cbrt % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.cbrt % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.cbrt % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -167,14 +304,28 @@ module @test_module { // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32 // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64 // CHECK-LABEL: func @math_ceil - func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.ceil %arg_f16 : f16 - // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.ceil %arg_f32 : f32 - // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.ceil %arg_f64 : f64 - // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_ceil(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.ceil % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.ceil % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.ceil % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -185,14 +336,28 @@ module @test_module { // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 // CHECK-LABEL: func @math_cos - func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.cos %arg_f16 : f16 - // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.cos %arg_f32 : f32 - // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.cos %arg_f64 : f64 - // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_cos(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.cos % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.cos % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.cos % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -203,14 +368,28 @@ module @test_module { // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64 // CHECK-LABEL: func @math_cosh - func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.cosh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.cosh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.cosh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_cosh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.cosh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.cosh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.cosh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -221,14 +400,28 @@ module @test_module { // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64 // CHECK-LABEL: func @math_sinh - func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.sinh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.sinh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.sinh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_sinh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.sinh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.sinh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.sinh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -238,12 +431,18 @@ module @test_module { // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @math_exp - func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { - %result16 = math.exp %arg_f16 : f16 - // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 - %result64 = math.exp %arg_f64 : f64 - // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result64 : f16, f64 + func.func @math_exp(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) { + % result16 = + math.exp % + arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 + % + result64 = math.exp % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 + func.return % result16, + % result64 : f16, f64 } } @@ -254,14 +453,28 @@ module @test_module { // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64 // CHECK-LABEL: func @math_exp2 - func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.exp2 %arg_f16 : f16 - // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.exp2 %arg_f32 : f32 - // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.exp2 %arg_f64 : f64 - // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_exp2(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.exp2 % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.exp2 % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.exp2 % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -272,14 +485,28 @@ module @test_module { // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32 // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64 // CHECK-LABEL: func @math_expm1 - func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.expm1 %arg_f16 : f16 - // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.expm1 %arg_f32 : f32 - // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.expm1 %arg_f64 : f64 - // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_expm1(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.expm1 % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.expm1 % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.expm1 % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -290,14 +517,28 @@ module @test_module { // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32 // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64 // CHECK-LABEL: func @math_floor - func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.floor %arg_f16 : f16 - // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.floor %arg_f32 : f32 - // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.floor %arg_f64 : f64 - // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_floor(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.floor % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.floor % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.floor % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -307,12 +548,18 @@ module @test_module { // CHECK: llvm.func @__ocml_log_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @math_log - func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { - %result16 = math.log %arg_f16 : f16 - // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 - %result64 = math.log %arg_f64 : f64 - // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result64 : f16, f64 + func.func @math_log(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) { + % result16 = + math.log % + arg_f16 : f16 + // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 + % + result64 = math.log % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 + func.return % result16, + % result64 : f16, f64 } } @@ -323,14 +570,28 @@ module @test_module { // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 // CHECK-LABEL: func @math_log10 - func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.log10 %arg_f16 : f16 - // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.log10 %arg_f32 : f32 - // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.log10 %arg_f64 : f64 - // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_log10(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.log10 % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.log10 % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.log10 % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -341,14 +602,28 @@ module @test_module { // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 // CHECK-LABEL: func @math_log1p - func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.log1p %arg_f16 : f16 - // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.log1p %arg_f32 : f32 - // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.log1p %arg_f64 : f64 - // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_log1p(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.log1p % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.log1p % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.log1p % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -359,14 +634,31 @@ module @test_module { // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_powf - func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.powf %arg_f16, %arg_f16 : f16 - // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 - %result32 = math.powf %arg_f32, %arg_f32 : f32 - // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = math.powf %arg_f64, %arg_f64 : f64 - // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_powf(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.powf % arg_f16, + % + arg_f16 : f16 + // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : + // (f16, f16) -> f16 + % + result32 = math.powf % arg_f32, + % + arg_f32 : f32 + // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : + // (f32, f32) -> f32 + % + result64 = math.powf % arg_f64, + % + arg_f64 : f64 + // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : + // (f64, f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -377,14 +669,28 @@ module @test_module { // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64 // CHECK-LABEL: func @math_rsqrt - func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.rsqrt %arg_f16 : f16 - // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.rsqrt %arg_f32 : f32 - // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.rsqrt %arg_f64 : f64 - // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_rsqrt(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.rsqrt % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.rsqrt % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.rsqrt % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -395,14 +701,28 @@ module @test_module { // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @math_sin - func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.sin %arg_f16 : f16 - // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.sin %arg_f32 : f32 - // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.sin %arg_f64 : f64 - // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_sin(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.sin % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.sin % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.sin % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -413,14 +733,28 @@ module @test_module { // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 // CHECK-LABEL: func @math_tanh - func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.tanh %arg_f16 : f16 - // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.tanh %arg_f32 : f32 - // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.tanh %arg_f64 : f64 - // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_tanh(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.tanh % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.tanh % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.tanh % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -431,14 +765,28 @@ module @test_module { // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 // CHECK-LABEL: func @math_tan - func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.tan %arg_f16 : f16 - // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.tan %arg_f32 : f32 - // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.tan %arg_f64 : f64 - // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_tan(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.tan % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.tan % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.tan % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -449,14 +797,28 @@ module @test_module { // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64 // CHECK-LABEL: func @math_erf - func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.erf %arg_f16 : f16 - // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.erf %arg_f32 : f32 - // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.erf %arg_f64 : f64 - // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_erf(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.erf % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.erf % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.erf % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -467,14 +829,28 @@ module @test_module { // CHECK: llvm.func @__ocml_erfc_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erfc_f64(f64) -> f64 // CHECK-LABEL: func @math_erfc - func.func @math_erfc(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { - %result16 = math.erfc %arg_f16 : f16 - // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16 - %result32 = math.erfc %arg_f32 : f32 - // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.erfc %arg_f64 : f64 - // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 + func.func @math_erfc(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64) + ->(f16, f32, f64) { + % result16 = math.erfc % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16 + % + result32 = math.erfc % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32 + % + result64 = math.erfc % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64 + func.return % + result16, + % result32, % result64 : f16, f32, f64 } } @@ -485,18 +861,36 @@ module @test_module { // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @math_casting - func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) { - %resultf16 = math.sin %arg_f16 : f16 - // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 - %resultf32 = math.sin %arg_f32 : f32 - // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - %resultf64 = math.sin %arg_f64 : f64 - // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - %resultbf16 = math.sin %arg_bf16 : bf16 - // CHECK: llvm.fpext %{{.*}} : bf16 to f32 - // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16 - func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16 + func.func @math_casting(% arg_f16 + : f16, % arg_f32 + : f32, % arg_f64 + : f64, % arg_bf16 + : bf16) + ->(f16, f32, f64, bf16) { + % resultf16 = math.sin % + arg_f16 + : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + % + resultf32 = math.sin % + arg_f32 + : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + % + resultf64 = math.sin % + arg_f64 + : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + % + resultbf16 = math.sin % + arg_bf16 + : bf16 + // CHECK: llvm.fpext %{{.*}} : bf16 to f32 + // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16 + func.return % + resultf16, + % resultf32, % resultf64, % resultbf16 : f16, f32, f64, bf16 } } @@ -507,14 +901,22 @@ module @test_module { // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32 // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64 // CHECK-LABEL: func @math_fpowi - func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) { + func.func @math_fpowi(% arg0 + : f16, % arg1 + : f32, % arg2 + : f64, % arg3 + : i32) + ->(f16, f32, f64) { // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16 - %0 = math.fpowi %arg0, %arg3 : f16, i32 - // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32 - %1 = math.fpowi %arg1, %arg3 : f32, i32 - // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64 - %2 = math.fpowi %arg2, %arg3 : f64, i32 - return %0, %1, %2 : f16, f32, f64 + % 0 = math.fpowi % arg0, % arg3 : f16, + i32 + // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32 + % 1 = math.fpowi % arg1, + % arg3 : f32, + i32 + // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64 + % 2 = math.fpowi % arg2, + % arg3 : f64, i32 return % 0, % 1, % 2 : f16, f32, f64 } } @@ -523,13 +925,13 @@ module @test_module { // Math operation not inside function // Ensure it not crash -module { - "test.some_op_with_region"() ({ - ^bb0(%arg0: f64): - // CHECK: math.atan - %0 = math.atan %arg0 : f64 - "test.possible_terminator"() : () -> () - }) : () -> () +module{ + "test.some_op_with_region"()({ + ^bb0(% arg0:f64) : + // CHECK: math.atan + % 0 = math.atan % arg0:f64 "test.possible_terminator"() : ()->() + }) : () + ->() } // ----- @@ -537,12 +939,11 @@ module { module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_0d - func.func @math_sin_vector_0d(%arg : vector) -> vector { + func.func @math_sin_vector_0d(% arg : vector)->vector { // CHECK: llvm.extractelement {{.*}} : vector<1xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<1xf16> - %result = math.sin %arg : vector - func.return %result : vector + % result = math.sin % arg : vector func.return % result : vector } } @@ -551,7 +952,7 @@ module @test_module { module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_1d - func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> { + func.func @math_sin_vector_1d(% arg : vector<4xf16>)->vector<4xf16> { // CHECK: llvm.extractelement {{.*}} : vector<4xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<4xf16> @@ -564,8 +965,8 @@ module @test_module { // CHECK: llvm.extractelement {{.*}} : vector<4xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<4xf16> - %result = math.sin %arg : vector<4xf16> - func.return %result : vector<4xf16> + % result = + math.sin % arg : vector<4xf16> func.return % result : vector<4xf16> } } @@ -574,11 +975,11 @@ module @test_module { module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_2d - func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> { - // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractelement {{.*}} : vector<2xf16> - // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + func.func @math_sin_vector_2d(% arg : vector<2x2xf16>)->vector<2x2xf16> { + // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to + // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractvalue {{.*}} : + // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractelement {{.*}} : + // vector<2xf16> CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<2xf16> // CHECK: llvm.extractelement {{.*}} : vector<2xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 @@ -591,8 +992,42 @@ module @test_module { // CHECK: llvm.extractelement {{.*}} : vector<2xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<2xf16> - // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - %result = math.sin %arg : vector<2x2xf16> - func.return %result : vector<2x2xf16> + // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + % result = + math.sin % arg : vector<2x2xf16> func.return % result : vector<2x2xf16> } } + +// ----- + +// f16 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f16(% x + : f16, % lo + : f16, % hi + : f16) + ->f16{ % r = math.clampf % x to[% lo, % hi] : f16 return % r : f16} + +// f32 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f32(% x + : f32, % lo + : f32, % hi + : f32) + ->f32 { + % r = math.clampf % x to[% lo, % hi] : f32 return % r : f32 +} + +// POST9-LABEL: func.func @clampf_f16 +// POST9: rocdl.fmed3 {{.*}} : f16 +// POST9: return + +// POST9-LABEL: func.func @clampf_f32 +// POST9: rocdl.fmed3 {{.*}} : f32 +// POST9: return + +// PRE9-LABEL: func.func @clampf_f16 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f16 + +// PRE9-LABEL: func.func @clampf_f32 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f32 From 92bcb55d165dcf4407b045a38d98f01bd2a0c2bc Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 13 Oct 2025 13:00:35 -0700 Subject: [PATCH 2/9] Removed incorrect formatting Signed-off-by: Keshav Vinayak Jha --- .../Conversion/MathToROCDL/math-to-rocdl.mlir | 927 +++++------------- 1 file changed, 261 insertions(+), 666 deletions(-) diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 29851e2de5cb2..7244b0aac8e43 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -1,40 +1,19 @@ -// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -// -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | -// FileCheck %s --check-prefix=PRE9 RUN: mlir-opt %s -allow-unregistered-dialect -// -split-input-file -// -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | -// FileCheck %s --check-prefix=POST9 +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9 +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9 module @test_module { // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 // CHECK-LABEL: func @arith_remf - func.func @arith_remf(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = arith.remf % arg_f16, - % - arg_f16 : f16 - // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : - // (f16, f16) -> f16 - % - result32 = arith.remf % arg_f32, - % - arg_f32 : f32 - // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : - // (f32, f32) -> f32 - % - result64 = arith.remf % arg_f64, - % - arg_f64 : f64 - // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : - // (f64, f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = arith.remf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 + %result32 = arith.remf %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = arith.remf %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -45,28 +24,14 @@ module @test_module { // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64 // CHECK-LABEL: func @math_acos - func.func @math_acos(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.acos % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.acos % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.acos % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.acos %arg_f16 : f16 + // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.acos %arg_f32 : f32 + // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.acos %arg_f64 : f64 + // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -77,28 +42,14 @@ module @test_module { // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64 // CHECK-LABEL: func @math_acosh - func.func @math_acosh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.acosh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.acosh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.acosh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.acosh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.acosh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.acosh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -109,28 +60,14 @@ module @test_module { // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64 // CHECK-LABEL: func @math_asin - func.func @math_asin(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.asin % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.asin % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.asin % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.asin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.asin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.asin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -141,28 +78,14 @@ module @test_module { // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64 // CHECK-LABEL: func @math_asinh - func.func @math_asinh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.asinh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.asinh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.asinh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.asinh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.asinh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.asinh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -173,28 +96,14 @@ module @test_module { // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 // CHECK-LABEL: func @math_atan - func.func @math_atan(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.atan % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.atan % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.atan % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.atan %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.atan %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -205,28 +114,14 @@ module @test_module { // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64 // CHECK-LABEL: func @math_atanh - func.func @math_atanh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.atanh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.atanh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.atanh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atanh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.atanh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.atanh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -237,31 +132,14 @@ module @test_module { // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_atan2 - func.func @math_atan2(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.atan2 % arg_f16, - % - arg_f16 : f16 - // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : - // (f16, f16) -> f16 - % - result32 = math.atan2 % arg_f32, - % - arg_f32 : f32 - // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : - // (f32, f32) -> f32 - % - result64 = math.atan2 % arg_f64, - % - arg_f64 : f64 - // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) - // : (f64, f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan2 %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 + %result32 = math.atan2 %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = math.atan2 %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -272,28 +150,14 @@ module @test_module { // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 // CHECK-LABEL: func @math_cbrt - func.func @math_cbrt(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.cbrt % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.cbrt % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.cbrt % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cbrt %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.cbrt %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cbrt %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -304,28 +168,14 @@ module @test_module { // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32 // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64 // CHECK-LABEL: func @math_ceil - func.func @math_ceil(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.ceil % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.ceil % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.ceil % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.ceil %arg_f16 : f16 + // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.ceil %arg_f32 : f32 + // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.ceil %arg_f64 : f64 + // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -336,28 +186,14 @@ module @test_module { // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 // CHECK-LABEL: func @math_cos - func.func @math_cos(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.cos % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.cos % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.cos % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cos %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.cos %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cos %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -368,28 +204,14 @@ module @test_module { // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64 // CHECK-LABEL: func @math_cosh - func.func @math_cosh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.cosh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.cosh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.cosh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cosh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.cosh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cosh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -400,28 +222,14 @@ module @test_module { // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64 // CHECK-LABEL: func @math_sinh - func.func @math_sinh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.sinh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.sinh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.sinh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.sinh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.sinh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sinh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -431,18 +239,12 @@ module @test_module { // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @math_exp - func.func @math_exp(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) { - % result16 = - math.exp % - arg_f16 : f16 - // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 - % - result64 = math.exp % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return % result16, - % result64 : f16, f64 + func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + %result16 = math.exp %arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 + %result64 = math.exp %arg_f64 : f64 + // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result64 : f16, f64 } } @@ -453,28 +255,14 @@ module @test_module { // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64 // CHECK-LABEL: func @math_exp2 - func.func @math_exp2(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.exp2 % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.exp2 % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.exp2 % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.exp2 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.exp2 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.exp2 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -485,28 +273,14 @@ module @test_module { // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32 // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64 // CHECK-LABEL: func @math_expm1 - func.func @math_expm1(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.expm1 % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.expm1 % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.expm1 % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.expm1 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.expm1 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.expm1 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -517,28 +291,14 @@ module @test_module { // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32 // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64 // CHECK-LABEL: func @math_floor - func.func @math_floor(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.floor % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.floor % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.floor % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.floor %arg_f16 : f16 + // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.floor %arg_f32 : f32 + // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.floor %arg_f64 : f64 + // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -548,18 +308,12 @@ module @test_module { // CHECK: llvm.func @__ocml_log_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @math_log - func.func @math_log(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) { - % result16 = - math.log % - arg_f16 : f16 - // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 - % - result64 = math.log % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return % result16, - % result64 : f16, f64 + func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + %result16 = math.log %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 + %result64 = math.log %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result64 : f16, f64 } } @@ -570,28 +324,14 @@ module @test_module { // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 // CHECK-LABEL: func @math_log10 - func.func @math_log10(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.log10 % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.log10 % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.log10 % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log10 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.log10 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.log10 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -602,28 +342,14 @@ module @test_module { // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 // CHECK-LABEL: func @math_log1p - func.func @math_log1p(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.log1p % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.log1p % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.log1p % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log1p %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.log1p %arg_f32 : f32 + // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.log1p %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -634,31 +360,14 @@ module @test_module { // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_powf - func.func @math_powf(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.powf % arg_f16, - % - arg_f16 : f16 - // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : - // (f16, f16) -> f16 - % - result32 = math.powf % arg_f32, - % - arg_f32 : f32 - // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : - // (f32, f32) -> f32 - % - result64 = math.powf % arg_f64, - % - arg_f64 : f64 - // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : - // (f64, f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.powf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 + %result32 = math.powf %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = math.powf %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -669,28 +378,14 @@ module @test_module { // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64 // CHECK-LABEL: func @math_rsqrt - func.func @math_rsqrt(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.rsqrt % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.rsqrt % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.rsqrt % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.rsqrt %arg_f16 : f16 + // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.rsqrt %arg_f32 : f32 + // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.rsqrt %arg_f64 : f64 + // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -701,28 +396,14 @@ module @test_module { // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @math_sin - func.func @math_sin(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.sin % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.sin % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.sin % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.sin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.sin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -733,28 +414,14 @@ module @test_module { // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 // CHECK-LABEL: func @math_tanh - func.func @math_tanh(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.tanh % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.tanh % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.tanh % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tanh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.tanh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.tanh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -765,28 +432,14 @@ module @test_module { // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 // CHECK-LABEL: func @math_tan - func.func @math_tan(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.tan % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.tan % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.tan % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.tan %arg_f32 : f32 + // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.tan %arg_f64 : f64 + // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -797,28 +450,14 @@ module @test_module { // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64 // CHECK-LABEL: func @math_erf - func.func @math_erf(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.erf % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.erf % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.erf % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.erf %arg_f16 : f16 + // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.erf %arg_f32 : f32 + // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.erf %arg_f64 : f64 + // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -829,28 +468,14 @@ module @test_module { // CHECK: llvm.func @__ocml_erfc_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erfc_f64(f64) -> f64 // CHECK-LABEL: func @math_erfc - func.func @math_erfc(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64) - ->(f16, f32, f64) { - % result16 = math.erfc % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16 - % - result32 = math.erfc % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32 - % - result64 = math.erfc % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64 - func.return % - result16, - % result32, % result64 : f16, f32, f64 + func.func @math_erfc(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.erfc %arg_f16 : f16 + // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16 + %result32 = math.erfc %arg_f32 : f32 + // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.erfc %arg_f64 : f64 + // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -861,36 +486,18 @@ module @test_module { // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @math_casting - func.func @math_casting(% arg_f16 - : f16, % arg_f32 - : f32, % arg_f64 - : f64, % arg_bf16 - : bf16) - ->(f16, f32, f64, bf16) { - % resultf16 = math.sin % - arg_f16 - : f16 - // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 - % - resultf32 = math.sin % - arg_f32 - : f32 - // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - % - resultf64 = math.sin % - arg_f64 - : f64 - // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - % - resultbf16 = math.sin % - arg_bf16 - : bf16 - // CHECK: llvm.fpext %{{.*}} : bf16 to f32 - // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16 - func.return % - resultf16, - % resultf32, % resultf64, % resultbf16 : f16, f32, f64, bf16 + func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) { + %resultf16 = math.sin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + %resultf32 = math.sin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %resultf64 = math.sin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + %resultbf16 = math.sin %arg_bf16 : bf16 + // CHECK: llvm.fpext %{{.*}} : bf16 to f32 + // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16 + func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16 } } @@ -901,22 +508,14 @@ module @test_module { // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32 // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64 // CHECK-LABEL: func @math_fpowi - func.func @math_fpowi(% arg0 - : f16, % arg1 - : f32, % arg2 - : f64, % arg3 - : i32) - ->(f16, f32, f64) { + func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) { // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16 - % 0 = math.fpowi % arg0, % arg3 : f16, - i32 - // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32 - % 1 = math.fpowi % arg1, - % arg3 : f32, - i32 - // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64 - % 2 = math.fpowi % arg2, - % arg3 : f64, i32 return % 0, % 1, % 2 : f16, f32, f64 + %0 = math.fpowi %arg0, %arg3 : f16, i32 + // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32 + %1 = math.fpowi %arg1, %arg3 : f32, i32 + // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64 + %2 = math.fpowi %arg2, %arg3 : f64, i32 + return %0, %1, %2 : f16, f32, f64 } } @@ -925,13 +524,13 @@ module @test_module { // Math operation not inside function // Ensure it not crash -module{ - "test.some_op_with_region"()({ - ^bb0(% arg0:f64) : - // CHECK: math.atan - % 0 = math.atan % arg0:f64 "test.possible_terminator"() : ()->() - }) : () - ->() +module { + "test.some_op_with_region"() ({ + ^bb0(%arg0: f64): + // CHECK: math.atan + %0 = math.atan %arg0 : f64 + "test.possible_terminator"() : () -> () + }) : () -> () } // ----- @@ -939,11 +538,12 @@ module{ module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_0d - func.func @math_sin_vector_0d(% arg : vector)->vector { + func.func @math_sin_vector_0d(%arg : vector) -> vector { // CHECK: llvm.extractelement {{.*}} : vector<1xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<1xf16> - % result = math.sin % arg : vector func.return % result : vector + %result = math.sin %arg : vector + func.return %result : vector } } @@ -952,7 +552,7 @@ module @test_module { module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_1d - func.func @math_sin_vector_1d(% arg : vector<4xf16>)->vector<4xf16> { + func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> { // CHECK: llvm.extractelement {{.*}} : vector<4xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<4xf16> @@ -965,8 +565,8 @@ module @test_module { // CHECK: llvm.extractelement {{.*}} : vector<4xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<4xf16> - % result = - math.sin % arg : vector<4xf16> func.return % result : vector<4xf16> + %result = math.sin %arg : vector<4xf16> + func.return %result : vector<4xf16> } } @@ -975,11 +575,11 @@ module @test_module { module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_2d - func.func @math_sin_vector_2d(% arg : vector<2x2xf16>)->vector<2x2xf16> { - // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to - // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractvalue {{.*}} : - // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractelement {{.*}} : - // vector<2xf16> CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> { + // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractelement {{.*}} : vector<2xf16> + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<2xf16> // CHECK: llvm.extractelement {{.*}} : vector<2xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 @@ -992,28 +592,24 @@ module @test_module { // CHECK: llvm.extractelement {{.*}} : vector<2xf16> // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 // CHECK: llvm.insertelement {{.*}} : vector<2xf16> - // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - % result = - math.sin % arg : vector<2x2xf16> func.return % result : vector<2x2xf16> + // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + %result = math.sin %arg : vector<2x2xf16> + func.return %result : vector<2x2xf16> } } // ----- // f16 clamp → rocdl.fmed3 on gfx9+ -func.func @clampf_f16(% x - : f16, % lo - : f16, % hi - : f16) - ->f16{ % r = math.clampf % x to[% lo, % hi] : f16 return % r : f16} +func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { + %r = math.clampf %x to [%lo, %hi] : f16 + return %r : f16 +} // f32 clamp → rocdl.fmed3 on gfx9+ -func.func @clampf_f32(% x - : f32, % lo - : f32, % hi - : f32) - ->f32 { - % r = math.clampf % x to[% lo, % hi] : f32 return % r : f32 +func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { + %r = math.clampf %x to [%lo, %hi] : f32 + return %r : f32 } // POST9-LABEL: func.func @clampf_f16 @@ -1030,4 +626,3 @@ func.func @clampf_f32(% x // PRE9-LABEL: func.func @clampf_f32 // PRE9-NOT: rocdl.fmed3 -// PRE9: math.clampf {{.*}} : f32 From 49b08f9a4ce206e9768b3f341c49f6377c21d116 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 13 Oct 2025 13:12:01 -0700 Subject: [PATCH 3/9] Corrected pass option Signed-off-by: Keshav Vinayak Jha --- mlir/include/mlir/Conversion/Passes.td | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index c3fd397e258ae..06bd82341acab 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -755,14 +755,6 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> { "func::FuncDialect", "vector::VectorDialect", ]; - let options = [ - Option<"chipset", "chipset", "std::string", - - - /*default=*/"\"gfx000\"", - "Chipset that these operations will run on"> - ]; - } //===----------------------------------------------------------------------===// @@ -793,6 +785,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [Option<"chipset", "chipset", "std::string", + /*default=*/"\"gfx000\"", + "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// From 636ef8d9581229c916f69819a0fc172a648124bb Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 14 Oct 2025 06:50:42 -0700 Subject: [PATCH 4/9] Addressed Comments by Krzysztof: 1. Added lit test for 1D and 2D vectors 2. Added unrolling support for ND inputs Signed-off-by: Keshav Vinayak Jha --- .../Conversion/MathToROCDL/MathToROCDL.cpp | 18 ++++++ .../Conversion/MathToROCDL/math-to-rocdl.mlir | 57 +++++++++++++++---- 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index ceb3d22c6bd59..d8e3c34399ad4 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -59,10 +60,27 @@ struct ClampFOpConversion final op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) + "): V_MED_F16 / V_MED3_F32 not supported.")); } + auto resultType = getTypeConverter()->convertType(op.getType()); + // Handle multi-dimensional vectors (converted to LLVM arrays) + if (auto arrayType = dyn_cast(resultType)) { + // Handle multi-dimensional vectors (converted to LLVM arrays) + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + typename math::ClampFOp::Adaptor adaptor(operands); + return rewriter.create( + op.getLoc(), llvm1DVectorTy, adaptor.getValue(), + adaptor.getMin(), adaptor.getMax()); + }, + rewriter); + } + + // Handle 1D vectors and scalars directly rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), op.getMin(), op.getMax()); return success(); } + amdgpu::Chipset chipset; }; diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 7244b0aac8e43..55d48fa0d27f1 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -601,28 +601,63 @@ module @test_module { // ----- // f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f16 func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { %r = math.clampf %x to [%lo, %hi] : f16 return %r : f16 + // POST9: rocdl.fmed3 {{.*}} : f16 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f16 } // f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f32 func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { %r = math.clampf %x to [%lo, %hi] : f32 return %r : f32 + // POST9: rocdl.fmed3 {{.*}} : f32 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f32 } -// POST9-LABEL: func.func @clampf_f16 -// POST9: rocdl.fmed3 {{.*}} : f16 -// POST9: return +// ----- + +// Vector f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f16 +func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf16> + return %r : vector<2xf16> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf16> +} + +// ----- -// POST9-LABEL: func.func @clampf_f32 -// POST9: rocdl.fmed3 {{.*}} : f32 -// POST9: return +// Vector f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f32 +func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf32> + return %r : vector<2xf32> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf32> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf32> +} -// PRE9-LABEL: func.func @clampf_f16 -// PRE9-NOT: rocdl.fmed3 -// PRE9: math.clampf {{.*}} : f16 +// ----- -// PRE9-LABEL: func.func @clampf_f32 -// PRE9-NOT: rocdl.fmed3 +// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors) +// CHECK-LABEL: func.func @clampf_vector_2d_f16 +func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16> + return %r : vector<2x2xf16> + // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2x2xf16> +} From 767c0aca77c01c7b15dc5750335d2284514d1c70 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 14 Oct 2025 07:07:22 -0700 Subject: [PATCH 5/9] Set chipset default value to empty Signed-off-by: Keshav Vinayak Jha --- mlir/include/mlir/Conversion/Passes.td | 6 +++++- .../lib/Conversion/MathToROCDL/MathToROCDL.cpp | 18 +++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 06bd82341acab..78a6df3ad8755 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. + + The chipset option specifies the target AMDGPU architecture. If the chipset + is empty, none of the chipset-dependent patterns are added and the pass + will not attempt to parse the chipset. }]; let dependentDialects = [ "arith::ArithDialect", @@ -786,7 +790,7 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "vector::VectorDialect", ]; let options = [Option<"chipset", "chipset", "std::string", - /*default=*/"\"gfx000\"", + /*default=*/"\"\"", "Chipset that these operations will run on">]; } diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index d8e3c34399ad4..aef768e225d67 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -84,9 +84,9 @@ struct ClampFOpConversion final amdgpu::Chipset chipset; }; -static void addChipsetDependentPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - amdgpu::Chipset chipset) { +void addChipsetDependentPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + amdgpu::Chipset chipset) { patterns.add(converter, chipset); } @@ -183,12 +183,20 @@ struct ConvertMathToROCDLPass final void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); MLIRContext *ctx = m.getContext(); - FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset); + + // Only populate chipset-dependent patterns if chipset is specified + if (!chipset.empty()) { + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) { + return signalPassFailure(); + } + populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset); + } + ConversionTarget target(getContext()); target .addLegalDialect Date: Tue, 14 Oct 2025 07:57:01 -0700 Subject: [PATCH 6/9] Pattern should only apply to f16/f32 types; added reject lit for bf16 Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 16 ++++++++++++++-- .../Conversion/MathToROCDL/math-to-rocdl.mlir | 10 ++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index aef768e225d67..38704157ba565 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -54,13 +54,25 @@ struct ClampFOpConversion final LogicalResult matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // V_MED3_F16/F32 only exists in gfx9+ artchitectures + // Only f16 and f32 types are supported by fmed3 + Type opTy = op.getType(); + auto resultType = getTypeConverter()->convertType(opTy); + + if (auto vectorType = dyn_cast(opTy)) { + opTy = vectorType.getElementType(); + } + + if (!opTy.isF16() && !opTy.isF32()) { + return rewriter.notifyMatchFailure( + op, "fmed3 only supports f16 and f32 types"); + } + + // V_MED3_F16/F32 only exists in gfx9+ architectures if (chipset.majorVersion < 9) { return rewriter.notifyMatchFailure( op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) + "): V_MED_F16 / V_MED3_F32 not supported.")); } - auto resultType = getTypeConverter()->convertType(op.getType()); // Handle multi-dimensional vectors (converted to LLVM arrays) if (auto arrayType = dyn_cast(resultType)) { // Handle multi-dimensional vectors (converted to LLVM arrays) diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 55d48fa0d27f1..959230ae6cd49 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -661,3 +661,13 @@ func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: // PRE9-NOT: rocdl.fmed3 // PRE9: math.clampf {{.*}} : vector<2x2xf16> } + +// ----- +// CHECK-LABEL: func.func @clampf_bf16 +func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 { + %r = math.clampf %x to [%lo, %hi] : bf16 + return %r : bf16 + // CHECK: math.clampf {{.*}} : bf16 + // CHECK-NOT: rocdl.fmed3 +} + From f25ec273391ffbb791fc95e19e01b6882203bbd2 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 14 Oct 2025 07:57:39 -0700 Subject: [PATCH 7/9] Formatting lit test Signed-off-by: Keshav Vinayak Jha --- mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 959230ae6cd49..455f886839604 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -670,4 +670,3 @@ func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 { // CHECK: math.clampf {{.*}} : bf16 // CHECK-NOT: rocdl.fmed3 } - From 9b50b9d903f54030c0d82298062a845ffc742e4f Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 14 Oct 2025 08:15:02 -0700 Subject: [PATCH 8/9] Moved GFX9+ condition to within addChipsetDependentPatterns Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 38704157ba565..0fb670020b964 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -67,12 +67,6 @@ struct ClampFOpConversion final op, "fmed3 only supports f16 and f32 types"); } - // V_MED3_F16/F32 only exists in gfx9+ architectures - if (chipset.majorVersion < 9) { - return rewriter.notifyMatchFailure( - op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) + - "): V_MED_F16 / V_MED3_F32 not supported.")); - } // Handle multi-dimensional vectors (converted to LLVM arrays) if (auto arrayType = dyn_cast(resultType)) { // Handle multi-dimensional vectors (converted to LLVM arrays) @@ -100,7 +94,10 @@ void addChipsetDependentPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset) { - patterns.add(converter, chipset); + // V_MED3_F16/F32 only exists in gfx9+ architectures + if (chipset.majorVersion >= 9) { + patterns.add(converter, chipset); + } } void mlir::populateMathToROCDLConversionPatterns( From 61af07cd931011944101f840087daa01bf8d2020 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 14 Oct 2025 10:05:47 -0700 Subject: [PATCH 9/9] Added valid default value for chipset to pass Signed-off-by: Keshav Vinayak Jha --- mlir/include/mlir/Conversion/Passes.td | 2 +- mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 33936da0190cc..a2eb335faac6c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -790,7 +790,7 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "vector::VectorDialect", ]; let options = [Option<"chipset", "chipset", "std::string", - /*default=*/"\"\"", + /*default=*/"\"gfx000\"", "Chipset that these operations will run on">]; } diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 0fb670020b964..4ba7eab64a785 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -62,7 +62,7 @@ struct ClampFOpConversion final opTy = vectorType.getElementType(); } - if (!opTy.isF16() && !opTy.isF32()) { + if (!isa(opTy)) { return rewriter.notifyMatchFailure( op, "fmed3 only supports f16 and f32 types"); } @@ -74,9 +74,9 @@ struct ClampFOpConversion final op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename math::ClampFOp::Adaptor adaptor(operands); - return rewriter.create( - op.getLoc(), llvm1DVectorTy, adaptor.getValue(), - adaptor.getMin(), adaptor.getMax()); + return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getValue(), adaptor.getMin(), + adaptor.getMax()); }, rewriter); } @@ -90,9 +90,9 @@ struct ClampFOpConversion final amdgpu::Chipset chipset; }; -void addChipsetDependentPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - amdgpu::Chipset chipset) { +static void addChipsetDependentPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + amdgpu::Chipset chipset) { // V_MED3_F16/F32 only exists in gfx9+ architectures if (chipset.majorVersion >= 9) {