diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index 264fb4966d39..5173e7f82c4b 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -106,6 +106,14 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
 // Returns the squeezed tensor or failure.
 FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                Value input, int64_t dim);
+
+// Float 16 limits
+constexpr float Float16Max = 65504.0f;
+constexpr float Float16Lowest = -65504.0f;
+
+// BFloat 16 limits
+constexpr float BFloat16Max = 3.38953139e38f;
+constexpr float BFloat16Lowest = -3.38953139e38f;
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 8f89567df6f7..0bc93f711ad6 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -871,8 +871,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Value self = adaptor.getSelf();
   auto selfTy = cast<TensorType>(self.getType());
+  auto outTy =
+      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+  auto outElemTy = outTy.getElementType();

-  if (!selfTy) {
+  if (!selfTy || !outTy) {
     return rewriter.notifyMatchFailure(op,
                                        "Only Tensor types supported in TOSA");
   }
@@ -883,12 +886,27 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
         op, "Only floating-point datatype legalization currently supported");
   }

+  FloatAttr minFloatAttr, maxFloatAttr;
+  if (outElemTy.isF16()) {
+    minFloatAttr = rewriter.getF16FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF16FloatAttr(Float16Max);
+  } else if (outElemTy.isBF16()) {
+    minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 0.0f);
+    maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), BFloat16Max);
+  } else if (outElemTy.isF32()) {
+    minFloatAttr = rewriter.getF32FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+  } else if (outElemTy.isF64()) {
+    minFloatAttr = rewriter.getF64FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF64FloatAttr(std::numeric_limits<double>::max());
+  } else {
+    return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+  }
+
   // Maps to tosa.clamp
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-      op, getTypeConverter()->convertType(op.getType()), self,
-      rewriter.getF32FloatAttr(0.0f),
-      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()),
+      op, outTy, self, minFloatAttr, maxFloatAttr,
       /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
   return success();
 }
@@ -5186,10 +5204,30 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
         op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
         /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
   } else {
-    FloatAttr minFloatAttr = rewriter.getF32FloatAttr(
-        isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
-    FloatAttr maxFloatAttr = rewriter.getF32FloatAttr(
-        isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
+    FloatAttr minFloatAttr, maxFloatAttr;
+    if (outElemTy.isF16()) {
+      minFloatAttr =
+          rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
+      maxFloatAttr =
+          rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
+    } else if (outElemTy.isBF16()) {
+      minFloatAttr = rewriter.getFloatAttr(
+          rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
+      maxFloatAttr = rewriter.getFloatAttr(
+          rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
+    } else if (outElemTy.isF32()) {
+      minFloatAttr = rewriter.getF32FloatAttr(
+          isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
+      maxFloatAttr = rewriter.getF32FloatAttr(
+          isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
+    } else if (outElemTy.isF64()) {
+      minFloatAttr = rewriter.getF64FloatAttr(
+          isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
+      maxFloatAttr = rewriter.getF64FloatAttr(
+          isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
+    } else {
+      return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+    }

     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
         op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8585,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(

   auto zi = self;

+  FloatAttr minFloatAttr, maxFloatAttr;
+  if (resultElemTy.isF16()) {
+    minFloatAttr = rewriter.getF16FloatAttr(eps);
+    maxFloatAttr = rewriter.getF16FloatAttr(1 - eps);
+  } else if (resultElemTy.isBF16()) {
+    minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), eps);
+    maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 1 - eps);
+  } else if (resultElemTy.isF32()) {
+    minFloatAttr = rewriter.getF32FloatAttr(eps);
+    maxFloatAttr = rewriter.getF32FloatAttr(1 - eps);
+  } else if (resultElemTy.isF64()) {
+    minFloatAttr = rewriter.getF64FloatAttr(eps);
+    maxFloatAttr = rewriter.getF64FloatAttr(1 - eps);
+  } else {
+    return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+  }
+
   // Clamp input to [eps, 1 - eps] when eps is not None
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   if (!isEpsNone) {
     zi = rewriter
              .create<tosa::ClampOp>(
-                 op->getLoc(), resultType, self,
-                 rewriter.getF32FloatAttr(static_cast<float>(eps)),
-                 rewriter.getF32FloatAttr(static_cast<float>(1 - eps)),
+                 op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
                  /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
              .getResult();
   }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 35c498e6ebf0..3ed071c4f150 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -530,6 +530,11 @@
     "ReflectionPad3dModuleBack_basic",
     # RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
     "NativeGroupNormModule_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -988,6 +993,11 @@
     "NativeGroupNormModule_basic",
     "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3392,6 +3402,11 @@
     # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
     "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3958,6 +3973,11 @@
     "ReplicationPad1dModule_3DInput_basic",
     "ReplicationPad3dModule_basic",
     "ReplicationPad3dModuleSingleIntPad_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 ONNX_TOSA_CRASHING_SET = {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index b9cd4cc78c50..c00e48f39e88 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -834,6 +834,52 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseReluBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.relu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseReluBFloat16Module())
+def ElementwiseReluModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2, low=-1).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseReluFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.relu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseReluFloat16Module())
+def ElementwiseReluModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2, low=-1).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class QuantizedReluInt8(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1769,6 +1815,62 @@ def ElementwiseClampModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.clamp(x, min=-2.0)
+        int_min = torch.clamp(x, min=-3)
+        float_max = torch.clamp(x, max=2.0)
+        int_max = torch.clamp(x, max=3)
+        both = torch.clamp(x, min=-5, max=5)
+        return float_min, int_min, float_max, int_max, both
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampBFloat16Module())
+def ElementwiseClampModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.clamp(x, min=-2.0)
+        int_min = torch.clamp(x, min=-3)
+        float_max = torch.clamp(x, max=2.0)
+        int_max = torch.clamp(x, max=3)
+        both = torch.clamp(x, min=-5, max=5)
+        return float_min, int_min, float_max, int_max, both
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampFloat16Module())
+def ElementwiseClampModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampMinModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1795,6 +1897,58 @@ def ElementwiseClampMinModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampMinBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.ops.aten.clamp_min(x, min=-2.0)
+        int_min = torch.ops.aten.clamp_min(x, min=2)
+        min = torch.ops.aten.clamp_min(x, min=11.0)
+        return float_min, int_min, min
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMinBFloat16Module())
+def ElementwiseClampMinModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampMinFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.ops.aten.clamp_min(x, min=-2.0)
+        int_min = torch.ops.aten.clamp_min(x, min=2)
+        min = torch.ops.aten.clamp_min(x, min=11.0)
+        return float_min, int_min, min
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMinFloat16Module())
+def ElementwiseClampMinModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampMaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1821,6 +1975,58 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampMaxBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_max = torch.ops.aten.clamp_max(x, max=2.0)
+        int_max = torch.ops.aten.clamp_max(x, max=3)
+        max = torch.ops.aten.clamp_max(x, max=-11.0)
+        return float_max, int_max, max
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMaxBFloat16Module())
+def ElementwiseClampMaxModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampMaxFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_max = torch.ops.aten.clamp_max(x, max=2.0)
+        int_max = torch.ops.aten.clamp_max(x, max=3)
+        max = torch.ops.aten.clamp_max(x, max=-11.0)
+        return float_max, int_max, max
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMaxFloat16Module())
+def ElementwiseClampMaxModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampTensorFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()