diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 76f5825025739..d759299cbf762 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -130,6 +130,10 @@ namespace arith { Value createProduct(OpBuilder &builder, Location loc, ArrayRef values); Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, Type resultType); + +// Map strings to float types. +std::optional parseFloatType(MLIRContext *ctx, StringRef name); + } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 2dd7f6431f03e..2974bb344ad96 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -56,11 +56,13 @@ void populateMathPolynomialApproximationPatterns( void populateUpliftToFMAPatterns(RewritePatternSet &patterns); namespace math { -void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter); -void populateLegalizeToF32ConversionTarget(ConversionTarget &target, - TypeConverter &typeConverter); -void populateLegalizeToF32Patterns(RewritePatternSet &patterns, - TypeConverter &typeConverter); +void populateExtendToSupportedTypesTypeConverter( + TypeConverter &typeConverter, const SetVector &sourceTypes, + Type targetType); +void populateExtendToSupportedTypesConversionTarget( + ConversionTarget &target, TypeConverter &typeConverter); +void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace math } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index e870e714bfda5..a84c89020d4f3 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> { let dependentDialects = ["math::MathDialect"]; } -def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> { +def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> { let summary = "Legalize floating-point math ops on low-precision floats"; let description = [{ On many targets, the math functions are not implemented for floating-point @@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> { This pass explicitly legalizes these math functions by inserting `arith.extf` and `arith.truncf` pairs around said op, which preserves - the original semantics while enabling lowering. + the original semantics while enabling lowering. The extra supported floating-point + types for the target are passed as arguments. Types f64 and f32 are implicitly + supported. As an exception, this pass does not legalize `math.fma`, because that is an operation frequently implemented at low precisions. }]; + let options = [ + ListOption<"extraTypeStrs", "extra-types", "std::string", + "MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">, + Option<"targetTypeStr", "target-type", "std::string", "\"f32\"", + "MLIR type to convert the unsupported source types to">, + ]; let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 0bf8c8942885e..b51444e884aae 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" @@ -49,30 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern { }; } // end namespace -/// Map strings to float types. This function is here because no one else needs -/// it yet, feel free to abstract it out. -static std::optional parseFloatType(MLIRContext *ctx, - StringRef name) { - Builder b(ctx); - return llvm::StringSwitch>(name) - .Case("f4E2M1FN", b.getFloat4E2M1FNType()) - .Case("f6E2M3FN", b.getFloat6E2M3FNType()) - .Case("f6E3M2FN", b.getFloat6E3M2FNType()) - .Case("f8E5M2", b.getFloat8E5M2Type()) - .Case("f8E4M3", b.getFloat8E4M3Type()) - .Case("f8E4M3FN", b.getFloat8E4M3FNType()) - .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType()) - .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType()) - .Case("f8E3M4", b.getFloat8E3M4Type()) - .Case("bf16", b.getBF16Type()) - .Case("f16", b.getF16Type()) - .Case("f32", b.getF32Type()) - .Case("f64", b.getF64Type()) - .Case("f80", b.getF80Type()) - .Case("f128", b.getF128Type()) - .Default(std::nullopt); -} - LogicalResult EmulateFloatPattern::match(Operation *op) const { if (getTypeConverter()->isLegal(op)) return failure(); @@ -156,7 +133,8 @@ void EmulateUnsupportedFloatsPass::runOnOperation() { SmallVector sourceTypes; Type targetType; - std::optional maybeTargetType = parseFloatType(ctx, targetTypeStr); + std::optional maybeTargetType = + arith::parseFloatType(ctx, targetTypeStr); if (!maybeTargetType) { emitError(UnknownLoc::get(ctx), "could not map target type '" + targetTypeStr + @@ -166,7 +144,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() { targetType = *maybeTargetType; for (StringRef sourceTypeStr : sourceTypeStrs) { std::optional maybeSourceType = - parseFloatType(ctx, sourceTypeStr); + arith::parseFloatType(ctx, sourceTypeStr); if (!maybeSourceType) { emitError(UnknownLoc::get(ctx), "could not map source type '" + sourceTypeStr + diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index e75db84b75e28..c0aa16cc0da40 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -357,4 +357,26 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); } +/// Map strings to float types. +std::optional parseFloatType(MLIRContext *ctx, StringRef name) { + Builder b(ctx); + return llvm::StringSwitch>(name) + .Case("f4E2M1FN", b.getFloat4E2M1FNType()) + .Case("f6E2M3FN", b.getFloat6E2M3FNType()) + .Case("f6E3M2FN", b.getFloat6E3M2FNType()) + .Case("f8E5M2", b.getFloat8E5M2Type()) + .Case("f8E4M3", b.getFloat8E4M3Type()) + .Case("f8E4M3FN", b.getFloat8E4M3FNType()) + .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType()) + .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType()) + .Case("f8E3M4", b.getFloat8E3M4Type()) + .Case("bf16", b.getBF16Type()) + .Case("f16", b.getF16Type()) + .Case("f32", b.getF32Type()) + .Case("f64", b.getF64Type()) + .Case("f80", b.getF80Type()) + .Case("f128", b.getF128Type()) + .Default(std::nullopt); +} + } // namespace mlir::arith diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index 2a5b4fbcb5271..e1c0c2410c126 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,7 +1,7 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp ExpandPatterns.cpp - LegalizeToF32.cpp + ExtendToSupportedTypes.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp new file mode 100644 index 0000000000000..1a9eafec9fdd5 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp @@ -0,0 +1,164 @@ +//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats +//----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements legalizing math operations on unsupported floating-point +// types through arith.extf and arith.truncf. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::math { +#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +using namespace mlir; + +namespace { +struct ExtendToSupportedTypesRewritePattern final : ConversionPattern { + ExtendToSupportedTypesRewritePattern(TypeConverter &converter, + MLIRContext *context) + : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +struct ExtendToSupportedTypesPass + : mlir::math::impl::MathExtendToSupportedTypesBase< + ExtendToSupportedTypesPass> { + using math::impl::MathExtendToSupportedTypesBase< + ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase; + + void runOnOperation() override; +}; +} // namespace + +void mlir::math::populateExtendToSupportedTypesTypeConverter( + TypeConverter &typeConverter, const SetVector &sourceTypes, + Type targetType) { + + typeConverter.addConversion( + [](Type type) -> std::optional { return type; }); + typeConverter.addConversion( + [&sourceTypes, targetType](FloatType type) -> std::optional { + if (!sourceTypes.contains(type)) + return targetType; + + return std::nullopt; + }); + typeConverter.addConversion( + [&sourceTypes, targetType](ShapedType type) -> std::optional { + if (auto elemTy = dyn_cast(type.getElementType())) + if (!sourceTypes.contains(elemTy)) + return type.clone(targetType); + + return std::nullopt; + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &b, Type target, ValueRange input, Location loc) { + auto extFOp = b.create(loc, target, input); + extFOp.setFastmath(arith::FastMathFlags::contract); + return extFOp; + }); +} + +void mlir::math::populateExtendToSupportedTypesConversionTarget( + ConversionTarget &target, TypeConverter &typeConverter) { + target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { + if (isa(op->getDialect())) + return typeConverter.isLegal(op); + return true; + }); + target.addLegalOp(); + target.addLegalOp(); +} + +LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + const TypeConverter *converter = getTypeConverter(); + FailureOr legalized = + convertOpResultTypes(op, operands, *converter, rewriter); + if (failed(legalized)) + return failure(); + + SmallVector results = (*legalized)->getResults(); + for (auto [result, newType, origType] : llvm::zip_equal( + results, (*legalized)->getResultTypes(), op->getResultTypes())) { + if (newType != origType) { + auto truncFOp = rewriter.create(loc, origType, result); + truncFOp.setFastmath(arith::FastMathFlags::contract); + result = truncFOp.getResult(); + } + } + rewriter.replaceOp(op, results); + return success(); +} + +void mlir::math::populateExtendToSupportedTypesPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter) { + patterns.add(typeConverter, + patterns.getContext()); +} + +void ExtendToSupportedTypesPass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *ctx = &getContext(); + + // Parse target type + std::optional maybeTargetType = + arith::parseFloatType(ctx, targetTypeStr); + if (!maybeTargetType.has_value()) { + emitError(UnknownLoc::get(ctx), "could not map target type '" + + targetTypeStr + + "' to a known floating-point type"); + return signalPassFailure(); + } + Type targetType = maybeTargetType.value(); + + // Parse source types + llvm::SetVector sourceTypes; + for (const auto &extraTypeStr : extraTypeStrs) { + std::optional maybeExtraType = + arith::parseFloatType(ctx, extraTypeStr); + if (!maybeExtraType.has_value()) { + emitError(UnknownLoc::get(ctx), "could not map source type '" + + extraTypeStr + + "' to a known floating-point type"); + return signalPassFailure(); + } + sourceTypes.insert(maybeExtraType.value()); + } + // f64 and f32 are implicitly supported + Builder b(ctx); + sourceTypes.insert(b.getF64Type()); + sourceTypes.insert(b.getF32Type()); + + TypeConverter typeConverter; + math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes, + targetType); + ConversionTarget target(*ctx); + math::populateExtendToSupportedTypesConversionTarget(target, typeConverter); + RewritePatternSet patterns(ctx); + math::populateExtendToSupportedTypesPatterns(patterns, typeConverter); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); +} diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp deleted file mode 100644 index 2e60fe455dcad..0000000000000 --- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp +++ /dev/null @@ -1,118 +0,0 @@ -//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements legalizing math operations on small floating-point -// types through arith.extf and arith.truncf. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/STLExtras.h" - -namespace mlir::math { -#define GEN_PASS_DEF_MATHLEGALIZETOF32 -#include "mlir/Dialect/Math/Transforms/Passes.h.inc" -} // namespace mlir::math - -using namespace mlir; -namespace { -struct LegalizeToF32RewritePattern final : ConversionPattern { - LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context) - : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -struct LegalizeToF32Pass final - : mlir::math::impl::MathLegalizeToF32Base { - void runOnOperation() override; -}; -} // namespace - -void mlir::math::populateLegalizeToF32TypeConverter( - TypeConverter &typeConverter) { - typeConverter.addConversion( - [](Type type) -> std::optional { return type; }); - typeConverter.addConversion([](FloatType type) -> std::optional { - if (type.getWidth() < 32) - return Float32Type::get(type.getContext()); - return std::nullopt; - }); - typeConverter.addConversion([](ShapedType type) -> std::optional { - if (auto elemTy = dyn_cast(type.getElementType())) - return type.clone(Float32Type::get(type.getContext())); - return std::nullopt; - }); - typeConverter.addTargetMaterialization( - [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); - extFOp.setFastmath(arith::FastMathFlags::contract); - return extFOp; - }); -} - -void mlir::math::populateLegalizeToF32ConversionTarget( - ConversionTarget &target, TypeConverter &typeConverter) { - target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { - if (isa(op->getDialect())) - return typeConverter.isLegal(op); - return true; - }); - target.addLegalOp(); - target.addLegalOp(); -} - -LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); - const TypeConverter *converter = getTypeConverter(); - FailureOr legalized = - convertOpResultTypes(op, operands, *converter, rewriter); - if (failed(legalized)) - return failure(); - - SmallVector results = (*legalized)->getResults(); - for (auto [result, newType, origType] : llvm::zip_equal( - results, (*legalized)->getResultTypes(), op->getResultTypes())) { - if (newType != origType) { - auto truncFOp = rewriter.create(loc, origType, result); - truncFOp.setFastmath(arith::FastMathFlags::contract); - result = truncFOp.getResult(); - } - } - rewriter.replaceOp(op, results); - return success(); -} - -void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns, - TypeConverter &typeConverter) { - patterns.add(typeConverter, - patterns.getContext()); -} - -void LegalizeToF32Pass::runOnOperation() { - Operation *op = getOperation(); - MLIRContext &ctx = getContext(); - - TypeConverter typeConverter; - math::populateLegalizeToF32TypeConverter(typeConverter); - ConversionTarget target(ctx); - math::populateLegalizeToF32ConversionTarget(target, typeConverter); - RewritePatternSet patterns(&ctx); - math::populateLegalizeToF32Patterns(patterns, typeConverter); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) - return signalPassFailure(); -} diff --git a/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir new file mode 100644 index 0000000000000..3674a91ef425f --- /dev/null +++ b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir @@ -0,0 +1,146 @@ +// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="extra-types=f16 target-type=f32" | FileCheck %s + +// CHECK-LABEL: @sin_f8E5M2 +// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2) +func.func @sin_f8E5M2(%arg0: f8E5M2) -> f8E5M2 { + // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] + // CHECK: [[SIN:%.+]] = math.sin [[EXTF]] + // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] + // CHECK: return [[TRUNCF]] : f8E5M2 + %0 = math.sin %arg0 : f8E5M2 + return %0 : f8E5M2 +} + +// CHECK-LABEL: @sin +// CHECK-SAME: ([[ARG0:%.+]]: f16) +func.func @sin(%arg0: f16) -> f16 { + // CHECK16: [[SIN:%.+]] = math.sin [[ARG0]] : f16 + // CHECK16: return [[SIN]] : f16 + %0 = math.sin %arg0 : f16 + return %0 : f16 +} + +// CHECK-LABEL: @fpowi_f8E5M2 +// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2, [[ARG1:%.+]]: i32) +func.func @fpowi_f8E5M2(%arg0: f8E5M2, %arg1: i32) -> f8E5M2 { + // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] + // CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]] + // CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]] + // CHECK: return [[TRUNCF]] : f8E5M2 + %0 = math.fpowi %arg0, %arg1 : f8E5M2, i32 + return %0 : f8E5M2 +} + +// CHECK-LABEL: @fpowi +// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32) +func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 { + // CHECK: [[FPOWI:%.+]] = math.fpowi [[ARG0]], [[ARG1]] + // CHECK: return [[FPOWI]] : f16 + %0 = math.fpowi %arg0, %arg1 : f16, i32 + return %0 : f16 +} + +// COM: Verify that the pass leaves `math.fma` untouched, since it is often +// COM: implemented on small data types. +// CHECK-LABEL: @fma +// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16) +// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]] +// CHECK: return [[FMA]] : f16 +func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 { + %0 = math.fma %arg0, %arg1, %arg2 : f16 + return %0 : f16 +} + +// CHECK-LABEL: @absf_f16 +// CHECK-SAME: ([[ARG0:%.+]]: f16) +// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] +// CHECK: return [[ABSF]] : f16 +func.func @absf_f16(%arg0: f16) -> f16 { + %0 = math.absf %arg0 : f16 + return %0 : f16 +} + +// CHECK-LABEL: @absf_f32 +// CHECK-SAME: ([[ARG0:%.+]]: f32) +// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] +// CHECK: return [[ABSF]] : f32 +func.func @absf_f32(%arg0: f32) -> f32 { + %0 = math.absf %arg0 : f32 + return %0 : f32 +} + +// CHECK-LABEL: @absf_f64 +// CHECK-SAME: ([[ARG0:%.+]]: f64) +// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] +// CHECK: return [[ABSF]] : f64 +func.func @absf_f64(%arg0: f64) -> f64 { + %0 = math.absf %arg0 : f64 + return %0 : f64 +} + +// CHECK-LABEL: @sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[SIN:%.+]] = math.sin [[EXTF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<2xbf16> +func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> { + %0 = math.sin %arg0 : vector<2xbf16> + return %0 : vector<2xbf16> +} + +// CHECK-LABEL: @sin_vector_f16 +// CHECK-SAME: ([[ARG0:%.+]]: vector<2xf16>) +// CHECK: [[SIN:%.+]] = math.sin [[ARG0]] +// CHECK: return [[SIN]] : vector<2xf16> +func.func @sin_vector_f16(%arg0: vector<2xf16>) -> vector<2xf16> { + %0 = math.sin %arg0 : vector<2xf16> + return %0 : vector<2xf16> +} + +// CHECK-LABEL: @fastmath +// CHECK: math.sin %{{.+}} fastmath +func.func @fastmath(%arg0: f16) -> f16 { + %0 = math.sin %arg0 fastmath : f16 + return %0 : f16 +} + +// CHECK-LABEL: @sequences_f8E5M2 +// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2) +// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] +// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]] +// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]] +// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF1]] : f8E5M2 +func.func @sequences_f8E5M2(%arg0: f8E5M2) -> f8E5M2 { + %0 = math.absf %arg0 : f8E5M2 + %1 = math.sin %0 : f8E5M2 + return %1 : f8E5M2 +} + +// CHECK-LABEL: @sequences +// CHECK-SAME: ([[ARG0:%.+]]: f16) +// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: return [[SIN]] : f16 +func.func @sequences(%arg0: f16) -> f16 { + %0 = math.absf %arg0 : f16 + %1 = math.sin %0 : f16 + return %1 : f16 +} + +// CHECK-LABEL: @promote_in_if_block +func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 { + // CHECK: [[EXTF0:%.+]] = arith.extf + // CHECK-NEXT: %[[RES:.*]] = scf.if + %0 = scf.if %arg2 -> bf16 { + %1 = math.absf %arg0 : bf16 + // CHECK: [[TRUNCF0:%.+]] = arith.truncf + scf.yield %1 : bf16 + } else { + scf.yield %arg1 : bf16 + } + return %0 : bf16 +} \ No newline at end of file diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/extend-to-supported-types.mlir similarity index 96% rename from mlir/test/Dialect/Math/legalize-to-f32.mlir rename to mlir/test/Dialect/Math/extend-to-supported-types.mlir index ebb0de9d2653e..ad7169d4cf4ae 100644 --- a/mlir/test/Dialect/Math/legalize-to-f32.mlir +++ b/mlir/test/Dialect/Math/extend-to-supported-types.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="target-type=f32" | FileCheck %s // CHECK-LABEL: @sin // CHECK-SAME: ([[ARG0:%.+]]: f16)