Skip to content

Commit f231ac1

Browse files
[TOSA] Add F16 and BF16 support for tosa.clamp
Signed-off-by: Justin Ngo <[email protected]> Change-Id: I0f1ed0249018b40d068af3a2d02a93776c0c4a5f
1 parent 46c3888 commit f231ac1

File tree

1 file changed

+80
-10
lines changed

1 file changed

+80
-10
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@
3131

3232
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
3333

34+
namespace {
35+
36+
namespace limits {
37+
// Float 16 limits: largest and lowest finite values of IEEE 754 half precision (+/-65504). Used in this file as the default tosa.clamp bounds when the element type is f16.
38+
constexpr float Float16Max = 65504.0f;
39+
constexpr float Float16Lowest = -65504.0f;
40+
41+
// BFloat 16 limits: largest and lowest finite bfloat16 values (about +/-3.39e38). Used in this file as the default tosa.clamp bounds when the element type is bf16.
42+
constexpr float BFloat16Max = 3.38953139e38f;
43+
constexpr float BFloat16Lowest = -3.38953139e38f;
44+
} // namespace limits
45+
46+
} // namespace
47+
3448
using namespace mlir;
3549
using namespace mlir::torch;
3650
using namespace mlir::torch::Torch;
@@ -871,6 +885,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
871885
ConversionPatternRewriter &rewriter) const {
872886
Value self = adaptor.getSelf();
873887
auto selfTy = cast<TensorType>(self.getType());
888+
auto outType =
889+
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
890+
auto outElemTy = outType.getElementType();
874891

875892
if (!selfTy) {
876893
return rewriter.notifyMatchFailure(op,
@@ -883,12 +900,28 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
883900
op, "Only floating-point datatype legalization currently supported");
884901
}
885902

903+
FloatAttr minFloatAttr, maxFloatAttr;
904+
if (outElemTy.isF16()) {
905+
minFloatAttr = rewriter.getF16FloatAttr(0.0f);
906+
maxFloatAttr = rewriter.getF16FloatAttr(limits::Float16Max);
907+
} else if (outElemTy.isBF16()) {
908+
minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 0.0f);
909+
maxFloatAttr =
910+
rewriter.getFloatAttr(rewriter.getBF16Type(), limits::BFloat16Max);
911+
} else if (outElemTy.isF32()) {
912+
minFloatAttr = rewriter.getF32FloatAttr(0.0f);
913+
maxFloatAttr = rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
914+
} else if (outElemTy.isF64()) {
915+
minFloatAttr = rewriter.getF64FloatAttr(0.0f);
916+
maxFloatAttr = rewriter.getF64FloatAttr(std::numeric_limits<double>::max());
917+
} else {
918+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
919+
}
920+
886921
// Maps to tosa.clamp
887922
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888923
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
889-
op, getTypeConverter()->convertType(op.getType()), self,
890-
rewriter.getF32FloatAttr(0.0f),
891-
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()),
924+
op, outType, self, minFloatAttr, maxFloatAttr,
892925
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
893926
return success();
894927
}
@@ -5186,10 +5219,32 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
51865219
op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
51875220
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
51885221
} else {
5189-
FloatAttr minFloatAttr = rewriter.getF32FloatAttr(
5190-
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
5191-
FloatAttr maxFloatAttr = rewriter.getF32FloatAttr(
5192-
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
5222+
FloatAttr minFloatAttr, maxFloatAttr;
5223+
if (outElemTy.isF16()) {
5224+
minFloatAttr = rewriter.getF16FloatAttr(
5225+
isMinNotNone ? minFloat : limits::Float16Lowest);
5226+
maxFloatAttr = rewriter.getF16FloatAttr(
5227+
isMaxNotNone ? maxFloat : limits::Float16Max);
5228+
} else if (outElemTy.isBF16()) {
5229+
minFloatAttr = rewriter.getFloatAttr(
5230+
rewriter.getBF16Type(),
5231+
isMinNotNone ? minFloat : limits::BFloat16Lowest);
5232+
maxFloatAttr =
5233+
rewriter.getFloatAttr(rewriter.getBF16Type(),
5234+
isMaxNotNone ? maxFloat : limits::BFloat16Max);
5235+
} else if (outElemTy.isF32()) {
5236+
minFloatAttr = rewriter.getF32FloatAttr(
5237+
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
5238+
maxFloatAttr = rewriter.getF32FloatAttr(
5239+
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
5240+
} else if (outElemTy.isF64()) {
5241+
minFloatAttr = rewriter.getF64FloatAttr(
5242+
isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
5243+
maxFloatAttr = rewriter.getF64FloatAttr(
5244+
isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
5245+
} else {
5246+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
5247+
}
51935248

51945249
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
51955250
op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8602,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85478602

85488603
auto zi = self;
85498604

8605+
FloatAttr minFloatAttr, maxFloatAttr;
8606+
if (resultElemTy.isF16()) {
8607+
minFloatAttr = rewriter.getF16FloatAttr(eps);
8608+
maxFloatAttr = rewriter.getF16FloatAttr(1 - eps);
8609+
} else if (resultElemTy.isBF16()) {
8610+
minFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), eps);
8611+
maxFloatAttr = rewriter.getFloatAttr(rewriter.getBF16Type(), 1 - eps);
8612+
} else if (resultElemTy.isF32()) {
8613+
minFloatAttr = rewriter.getF32FloatAttr(eps);
8614+
maxFloatAttr = rewriter.getF32FloatAttr(1 - eps);
8615+
} else if (resultElemTy.isF64()) {
8616+
minFloatAttr = rewriter.getF64FloatAttr(eps);
8617+
maxFloatAttr = rewriter.getF64FloatAttr(1 - eps);
8618+
} else {
8619+
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
8620+
}
8621+
85508622
// Clamp input to [eps, 1 - eps] when eps is not None
85518623
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
85528624
if (!isEpsNone) {
85538625
zi = rewriter
85548626
.create<tosa::ClampOp>(
8555-
op->getLoc(), resultType, self,
8556-
rewriter.getF32FloatAttr(static_cast<float>(eps)),
8557-
rewriter.getF32FloatAttr(static_cast<float>(1 - eps)),
8627+
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
85588628
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
85598629
.getResult();
85608630
}

0 commit comments

Comments
 (0)