Skip to content

Commit 83f41f9

Browse files
[TOSA] Add F16 and BF16 support for tosa.clamp
Signed-off-by: Justin Ngo <[email protected]>
1 parent 46c3888 commit 83f41f9

File tree

4 files changed

+298
-11
lines changed

4 files changed

+298
-11
lines changed

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
106106
// Returns the squeezed tensor or failure.
107107
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
108108
Value input, int64_t dim);
109+
110+
// Float 16 limits
111+
constexpr float Float16Max = 65504.0f;
112+
constexpr float Float16Lowest = -65504.0f;
113+
114+
// BFloat 16 limits
115+
constexpr float BFloat16Max = 3.38953139e38f;
116+
constexpr float BFloat16Lowest = -3.38953139e38f;
109117
} // namespace Torch
110118
} // namespace torch
111119
} // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -871,8 +871,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
871871
ConversionPatternRewriter &rewriter) const {
872872
Value self = adaptor.getSelf();
873873
auto selfTy = cast<TensorType>(self.getType());
874+
auto outTy =
875+
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
876+
auto outElemTy = outTy.getElementType();
874877

875-
if (!selfTy) {
878+
if (!selfTy || !outTy) {
876879
return rewriter.notifyMatchFailure(op,
877880
"Only Tensor types supported in TOSA");
878881
}
@@ -883,12 +886,27 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
883886
op, "Only floating-point datatype legalization currently supported");
884887
}
885888

889+
FloatAttr minFloatAttr, maxFloatAttr;
890+
if (outElemTy.isF16()) {
891+
minFloatAttr = rewriter.getF16FloatAttr(0.0f);
892+
maxFloatAttr = rewriter.getF16FloatAttr(Float16Max);
893+
} else if (outElemTy.isBF16()) {
894+
minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 0.0f);
895+
maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), BFloat16Max);
896+
} else if (outElemTy.isF32()) {
897+
minFloatAttr = rewriter.getF32FloatAttr(0.0f);
898+
maxFloatAttr = rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
899+
} else if (outElemTy.isF64()) {
900+
minFloatAttr = rewriter.getF64FloatAttr(0.0f);
901+
maxFloatAttr = rewriter.getF64FloatAttr(std::numeric_limits<double>::max());
902+
} else {
903+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
904+
}
905+
886906
// Maps to tosa.clamp
887907
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888908
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
889-
op, getTypeConverter()->convertType(op.getType()), self,
890-
rewriter.getF32FloatAttr(0.0f),
891-
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()),
909+
op, outTy, self, minFloatAttr, maxFloatAttr,
892910
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
893911
return success();
894912
}
@@ -5186,10 +5204,30 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
51865204
op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
51875205
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
51885206
} else {
5189-
FloatAttr minFloatAttr = rewriter.getF32FloatAttr(
5190-
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
5191-
FloatAttr maxFloatAttr = rewriter.getF32FloatAttr(
5192-
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
5207+
FloatAttr minFloatAttr, maxFloatAttr;
5208+
if (outElemTy.isF16()) {
5209+
minFloatAttr =
5210+
rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
5211+
maxFloatAttr =
5212+
rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
5213+
} else if (outElemTy.isBF16()) {
5214+
minFloatAttr = rewriter.getFloatAttr(
5215+
rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
5216+
maxFloatAttr = rewriter.getFloatAttr(
5217+
rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
5218+
} else if (outElemTy.isF32()) {
5219+
minFloatAttr = rewriter.getF32FloatAttr(
5220+
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
5221+
maxFloatAttr = rewriter.getF32FloatAttr(
5222+
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
5223+
} else if (outElemTy.isF64()) {
5224+
minFloatAttr = rewriter.getF64FloatAttr(
5225+
isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
5226+
maxFloatAttr = rewriter.getF64FloatAttr(
5227+
isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
5228+
} else {
5229+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
5230+
}
51935231

51945232
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
51955233
op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8585,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85478585

85488586
auto zi = self;
85498587

8588+
FloatAttr minFloatAttr, maxFloatAttr;
8589+
if (resultElemTy.isF16()) {
8590+
minFloatAttr = rewriter.getF16FloatAttr(eps);
8591+
maxFloatAttr = rewriter.getF16FloatAttr(1 - eps);
8592+
} else if (resultElemTy.isBF16()) {
8593+
minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), eps);
8594+
maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 1 - eps);
8595+
} else if (resultElemTy.isF32()) {
8596+
minFloatAttr = rewriter.getF32FloatAttr(eps);
8597+
maxFloatAttr = rewriter.getF32FloatAttr(1 - eps);
8598+
} else if (resultElemTy.isF64()) {
8599+
minFloatAttr = rewriter.getF64FloatAttr(eps);
8600+
maxFloatAttr = rewriter.getF64FloatAttr(1 - eps);
8601+
} else {
8602+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
8603+
}
8604+
85508605
// Clamp input to [eps, 1 - eps] when eps is not None
85518606
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
85528607
if (!isEpsNone) {
85538608
zi = rewriter
85548609
.create<tosa::ClampOp>(
8555-
op->getLoc(), resultType, self,
8556-
rewriter.getF32FloatAttr(static_cast<float>(eps)),
8557-
rewriter.getF32FloatAttr(static_cast<float>(1 - eps)),
8610+
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
85588611
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
85598612
.getResult();
85608613
}

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,11 @@
530530
"ReflectionPad3dModuleBack_basic",
531531
# RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
532532
"NativeGroupNormModule_basic",
533+
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
534+
"ElementwiseClampMaxModule_bfloat16",
535+
"ElementwiseClampMinModule_bfloat16",
536+
"ElementwiseClampModule_bfloat16",
537+
"ElementwiseReluModule_bfloat16",
533538
}
534539

535540
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -988,6 +993,11 @@
988993
"NativeGroupNormModule_basic",
989994
"AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
990995
"MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
996+
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
997+
"ElementwiseClampMaxModule_bfloat16",
998+
"ElementwiseClampMinModule_bfloat16",
999+
"ElementwiseClampModule_bfloat16",
1000+
"ElementwiseReluModule_bfloat16",
9911001
}
9921002

9931003
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3392,6 +3402,11 @@
33923402
# RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
33933403
"AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
33943404
"MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
3405+
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<?x?xbf16>'
3406+
"ElementwiseClampMaxModule_bfloat16",
3407+
"ElementwiseClampMinModule_bfloat16",
3408+
"ElementwiseClampModule_bfloat16",
3409+
"ElementwiseReluModule_bfloat16",
33953410
}
33963411

33973412
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3958,6 +3973,11 @@
39583973
"ReplicationPad1dModule_3DInput_basic",
39593974
"ReplicationPad3dModule_basic",
39603975
"ReplicationPad3dModuleSingleIntPad_basic",
3976+
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
3977+
"ElementwiseClampMaxModule_bfloat16",
3978+
"ElementwiseClampMinModule_bfloat16",
3979+
"ElementwiseClampModule_bfloat16",
3980+
"ElementwiseReluModule_bfloat16",
39613981
}
39623982

39633983
ONNX_TOSA_CRASHING_SET = {

0 commit comments

Comments (0)