Commit 9378eab

[TOSA] Add F16 and BF16 support for tosa.clamp
Signed-off-by: Justin Ngo <[email protected]>
1 parent 46c3888 commit 9378eab

4 files changed: +293 / -11 lines


include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
@@ -106,6 +106,14 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
 // Returns the squeezed tensor or failure.
 FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                Value input, int64_t dim);
+
+// Float 16 limits
+constexpr float Float16Max = 65504.0f;
+constexpr float Float16Lowest = -65504.0f;
+
+// BFloat 16 limits
+constexpr float BFloat16Max = 3.38953139e38f;
+constexpr float BFloat16Lowest = -3.38953139e38f;
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
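These constants mirror the finite range of the two 16-bit float formats: IEEE half tops out at 65504, while bfloat16 keeps float32's 8-bit exponent and so reaches almost (but not quite) float32's maximum. A quick hedged sanity check against PyTorch's own dtype metadata, not part of the commit:

import torch

# Largest/smallest finite float16 values match the new Float16Max/Float16Lowest.
assert torch.finfo(torch.float16).max == 65504.0
assert torch.finfo(torch.float16).min == -65504.0

# bfloat16 shares float32's exponent range, so its max normal is ~3.3895e38,
# slightly below float32's ~3.4028e38; the header rounds it to 3.38953139e38f.
print(torch.finfo(torch.bfloat16).max)  # 3.3895313892515355e+38
print(torch.finfo(torch.float32).max)   # 3.4028234663852886e+38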

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 64 additions & 11 deletions
@@ -871,8 +871,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Value self = adaptor.getSelf();
   auto selfTy = cast<TensorType>(self.getType());
+  auto outTy =
+      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+  auto outElemTy = outTy.getElementType();

-  if (!selfTy) {
+  if (!selfTy || !outTy) {
     return rewriter.notifyMatchFailure(op,
                                        "Only Tensor types supported in TOSA");
   }
@@ -883,12 +886,27 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
         op, "Only floating-point datatype legalization currently supported");
   }

+  FloatAttr minFloatAttr, maxFloatAttr;
+  if (outElemTy.isF16()) {
+    minFloatAttr = rewriter.getF16FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF16FloatAttr(Float16Max);
+  } else if (outElemTy.isBF16()) {
+    minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 0.0f);
+    maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), BFloat16Max);
+  } else if (outElemTy.isF32()) {
+    minFloatAttr = rewriter.getF32FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+  } else if (outElemTy.isF64()) {
+    minFloatAttr = rewriter.getF64FloatAttr(0.0f);
+    maxFloatAttr = rewriter.getF64FloatAttr(std::numeric_limits<double>::max());
+  } else {
+    return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+  }
+
   // Maps to tosa.clamp
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-      op, getTypeConverter()->convertType(op.getType()), self,
-      rewriter.getF32FloatAttr(0.0f),
-      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()),
+      op, outTy, self, minFloatAttr, maxFloatAttr,
       /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
   return success();
 }
@@ -5186,10 +5204,30 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
         op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
         /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
   } else {
-    FloatAttr minFloatAttr = rewriter.getF32FloatAttr(
-        isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
-    FloatAttr maxFloatAttr = rewriter.getF32FloatAttr(
-        isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
+    FloatAttr minFloatAttr, maxFloatAttr;
+    if (outElemTy.isF16()) {
+      minFloatAttr =
+          rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
+      maxFloatAttr =
+          rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
+    } else if (outElemTy.isBF16()) {
+      minFloatAttr = rewriter.getFloatAttr(
+          rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
+      maxFloatAttr = rewriter.getFloatAttr(
+          rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
+    } else if (outElemTy.isF32()) {
+      minFloatAttr = rewriter.getF32FloatAttr(
+          isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
+      maxFloatAttr = rewriter.getF32FloatAttr(
+          isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
+    } else if (outElemTy.isF64()) {
+      minFloatAttr = rewriter.getF64FloatAttr(
+          isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
+      maxFloatAttr = rewriter.getF64FloatAttr(
+          isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
+    } else {
+      return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+    }

     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
         op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8585,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(

   auto zi = self;

+  FloatAttr minFloatAttr, maxFloatAttr;
+  if (resultElemTy.isF16()) {
+    minFloatAttr = rewriter.getF16FloatAttr(eps);
+    maxFloatAttr = rewriter.getF16FloatAttr(1 - eps);
+  } else if (resultElemTy.isBF16()) {
+    minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), eps);
+    maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 1 - eps);
+  } else if (resultElemTy.isF32()) {
+    minFloatAttr = rewriter.getF32FloatAttr(eps);
+    maxFloatAttr = rewriter.getF32FloatAttr(1 - eps);
+  } else if (resultElemTy.isF64()) {
+    minFloatAttr = rewriter.getF64FloatAttr(eps);
+    maxFloatAttr = rewriter.getF64FloatAttr(1 - eps);
+  } else {
+    return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+  }
+
   // Clamp input to [eps, 1 - eps] when eps is not None
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   if (!isEpsNone) {
     zi = rewriter
              .create<tosa::ClampOp>(
-                 op->getLoc(), resultType, self,
-                 rewriter.getF32FloatAttr(static_cast<float>(eps)),
-                 rewriter.getF32FloatAttr(static_cast<float>(1 - eps)),
+                 op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
                  /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
              .getResult();
   }
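Across all three patterns the change is the same: the clamp bounds are now built as FloatAttrs in the converted result's element type (f16, bf16, f32, or f64) instead of always f32, and any bound the op leaves implicit (relu's upper bound, clamp's None min/max) defaults to that dtype's own lowest/max so the clamp is a no-op on that side. A hedged Python sketch of that bound-selection logic, purely illustrative and not part of the torch-mlir API:

import torch

# Illustrative table of per-dtype finite ranges; the f16/bf16 entries are the
# constants the commit adds to Utils.h, the f32/f64 entries come from finfo.
_DTYPE_RANGE = {
    torch.float16: (-65504.0, 65504.0),
    torch.bfloat16: (-3.38953139e38, 3.38953139e38),
    torch.float32: (torch.finfo(torch.float32).min, torch.finfo(torch.float32).max),
    torch.float64: (torch.finfo(torch.float64).min, torch.finfo(torch.float64).max),
}


def clamp_bounds(dtype, min_val=None, max_val=None):
    """Return (min, max) clamp attributes, defaulting an absent bound to the
    dtype's own finite range, as the TorchToTosa patterns now do."""
    if dtype not in _DTYPE_RANGE:
        # analogue of notifyMatchFailure("Unsupported floating-point type")
        raise TypeError("Unsupported floating-point type")
    lo, hi = _DTYPE_RANGE[dtype]
    return (
        min_val if min_val is not None else lo,
        max_val if max_val is not None else hi,
    )


# relu in f16: clamp to [0, Float16Max]
print(clamp_bounds(torch.float16, min_val=0.0))   # (0.0, 65504.0)
# clamp with only max given in bf16: min defaults to BFloat16Lowest
print(clamp_bounds(torch.bfloat16, max_val=2.0))  # (-3.38953139e+38, 2.0)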

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 15 additions & 0 deletions
@@ -530,6 +530,11 @@
     "ReflectionPad3dModuleBack_basic",
     # RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
     "NativeGroupNormModule_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -3392,6 +3397,11 @@
     # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
     "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<?x?xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3958,6 +3968,11 @@
     "ReplicationPad1dModule_3DInput_basic",
     "ReplicationPad3dModule_basic",
     "ReplicationPad3dModuleSingleIntPad_basic",
+    # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
+    "ElementwiseClampMaxModule_bfloat16",
+    "ElementwiseClampMinModule_bfloat16",
+    "ElementwiseClampModule_bfloat16",
+    "ElementwiseReluModule_bfloat16",
 }

 ONNX_TOSA_CRASHING_SET = {

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 206 additions & 0 deletions
@@ -834,6 +834,52 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseReluBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.relu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseReluBFloat16Module())
+def ElementwiseReluModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2, low=-1).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseReluFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.relu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseReluFloat16Module())
+def ElementwiseReluModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2, low=-1).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class QuantizedReluInt8(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1769,6 +1815,62 @@ def ElementwiseClampModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.clamp(x, min=-2.0)
+        int_min = torch.clamp(x, min=-3)
+        float_max = torch.clamp(x, max=2.0)
+        int_max = torch.clamp(x, max=3)
+        both = torch.clamp(x, min=-5, max=5)
+        return float_min, int_min, float_max, int_max, both
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampBFloat16Module())
+def ElementwiseClampModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.clamp(x, min=-2.0)
+        int_min = torch.clamp(x, min=-3)
+        float_max = torch.clamp(x, max=2.0)
+        int_max = torch.clamp(x, max=3)
+        both = torch.clamp(x, min=-5, max=5)
+        return float_min, int_min, float_max, int_max, both
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampFloat16Module())
+def ElementwiseClampModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampMinModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1795,6 +1897,58 @@ def ElementwiseClampMinModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampMinBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.ops.aten.clamp_min(x, min=-2.0)
+        int_min = torch.ops.aten.clamp_min(x, min=2)
+        min = torch.ops.aten.clamp_min(x, min=11.0)
+        return float_min, int_min, min
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMinBFloat16Module())
+def ElementwiseClampMinModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampMinFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_min = torch.ops.aten.clamp_min(x, min=-2.0)
+        int_min = torch.ops.aten.clamp_min(x, min=2)
+        min = torch.ops.aten.clamp_min(x, min=11.0)
+        return float_min, int_min, min
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMinFloat16Module())
+def ElementwiseClampMinModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampMaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1821,6 +1975,58 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseClampMaxBFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bfloat16, True),
+        ]
+    )
+    def forward(self, x):
+        float_max = torch.ops.aten.clamp_max(x, max=2.0)
+        int_max = torch.ops.aten.clamp_max(x, max=3)
+        max = torch.ops.aten.clamp_max(x, max=-11.0)
+        return float_max, int_max, max
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMaxBFloat16Module())
+def ElementwiseClampMaxModule_bfloat16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.bfloat16))
+
+
+# ==============================================================================
+
+
+class ElementwiseClampMaxFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, x):
+        float_max = torch.ops.aten.clamp_max(x, max=2.0)
+        int_max = torch.ops.aten.clamp_max(x, max=3)
+        max = torch.ops.aten.clamp_max(x, max=-11.0)
+        return float_max, int_max, max
+
+
+@register_test_case(module_factory=lambda: ElementwiseClampMaxFloat16Module())
+def ElementwiseClampMaxModule_float16(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, low=-10, high=10).to(torch.float16))
+
+
+# ==============================================================================
+
+
 class ElementwiseClampTensorFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
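The new e2e modules run relu, clamp, clamp_min, and clamp_max on float16 and bfloat16 inputs, and the harness checks them against eager PyTorch. As a rough standalone illustration of what that verifies, here is a hedged sketch outside the test framework: clamping in reduced precision should agree exactly with clamping a float32 copy and casting back, since the bounds used in these tests are exactly representable in both 16-bit formats.

import torch

torch.manual_seed(0)
x32 = torch.empty(3, 5).uniform_(-10, 10)

for dtype in (torch.float16, torch.bfloat16):
    x = x32.to(dtype)
    got = torch.clamp(x, min=-5, max=5)
    ref = torch.clamp(x.to(torch.float32), min=-5, max=5).to(dtype)
    assert torch.equal(got, ref), f"mismatch for {dtype}"
print("low-precision clamp matches the float32 reference")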
