31
31
32
32
#include " mlir/Dialect/Tosa/Utils/QuantUtils.h"
33
33
34
namespace {

/// Finite range limits of the 16-bit floating-point formats, stored as
/// `float` constants. Elsewhere in this file they serve as the fallback
/// clamp bounds for f16 / bf16 element types when no explicit bound is given.
namespace limits {
// Float 16 (IEEE 754 binary16) limits: +/- (2 - 2^-10) * 2^15.
constexpr float Float16Max = 65504.0f;
constexpr float Float16Lowest = -65504.0f;

// BFloat 16 limits: +/- (2 - 2^-7) * 2^127.
constexpr float BFloat16Max = 3.38953139e38f;
constexpr float BFloat16Lowest = -3.38953139e38f;
} // namespace limits

} // namespace
47
+
34
48
using namespace mlir ;
35
49
using namespace mlir ::torch;
36
50
using namespace mlir ::torch::Torch;
@@ -871,6 +885,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
871
885
ConversionPatternRewriter &rewriter) const {
872
886
Value self = adaptor.getSelf ();
873
887
auto selfTy = cast<TensorType>(self.getType ());
888
+ auto outType =
889
+ dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
890
+ auto outElemTy = outType.getElementType ();
874
891
875
892
if (!selfTy) {
876
893
return rewriter.notifyMatchFailure (op,
@@ -883,12 +900,28 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
883
900
op, " Only floating-point datatype legalization currently supported" );
884
901
}
885
902
903
+ FloatAttr minFloatAttr, maxFloatAttr;
904
+ if (outElemTy.isF16 ()) {
905
+ minFloatAttr = rewriter.getF16FloatAttr (0 .0f );
906
+ maxFloatAttr = rewriter.getF16FloatAttr (limits::Float16Max);
907
+ } else if (outElemTy.isBF16 ()) {
908
+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 0 .0f );
909
+ maxFloatAttr =
910
+ rewriter.getFloatAttr (rewriter.getBF16Type (), limits::BFloat16Max);
911
+ } else if (outElemTy.isF32 ()) {
912
+ minFloatAttr = rewriter.getF32FloatAttr (0 .0f );
913
+ maxFloatAttr = rewriter.getF32FloatAttr (std::numeric_limits<float >::max ());
914
+ } else if (outElemTy.isF64 ()) {
915
+ minFloatAttr = rewriter.getF64FloatAttr (0 .0f );
916
+ maxFloatAttr = rewriter.getF64FloatAttr (std::numeric_limits<double >::max ());
917
+ } else {
918
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
919
+ }
920
+
886
921
// Maps to tosa.clamp
887
922
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888
923
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
889
- op, getTypeConverter ()->convertType (op.getType ()), self,
890
- rewriter.getF32FloatAttr (0 .0f ),
891
- rewriter.getF32FloatAttr (std::numeric_limits<float >::max ()),
924
+ op, outType, self, minFloatAttr, maxFloatAttr,
892
925
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
893
926
return success ();
894
927
}
@@ -5186,10 +5219,32 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
5186
5219
op, outType, adaptor.getSelf (), minIntAttr, maxIntAttr,
5187
5220
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
5188
5221
} else {
5189
- FloatAttr minFloatAttr = rewriter.getF32FloatAttr (
5190
- isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5191
- FloatAttr maxFloatAttr = rewriter.getF32FloatAttr (
5192
- isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5222
+ FloatAttr minFloatAttr, maxFloatAttr;
5223
+ if (outElemTy.isF16 ()) {
5224
+ minFloatAttr = rewriter.getF16FloatAttr (
5225
+ isMinNotNone ? minFloat : limits::Float16Lowest);
5226
+ maxFloatAttr = rewriter.getF16FloatAttr (
5227
+ isMaxNotNone ? maxFloat : limits::Float16Max);
5228
+ } else if (outElemTy.isBF16 ()) {
5229
+ minFloatAttr = rewriter.getFloatAttr (
5230
+ rewriter.getBF16Type (),
5231
+ isMinNotNone ? minFloat : limits::BFloat16Lowest);
5232
+ maxFloatAttr =
5233
+ rewriter.getFloatAttr (rewriter.getBF16Type (),
5234
+ isMaxNotNone ? maxFloat : limits::BFloat16Max);
5235
+ } else if (outElemTy.isF32 ()) {
5236
+ minFloatAttr = rewriter.getF32FloatAttr (
5237
+ isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5238
+ maxFloatAttr = rewriter.getF32FloatAttr (
5239
+ isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5240
+ } else if (outElemTy.isF64 ()) {
5241
+ minFloatAttr = rewriter.getF64FloatAttr (
5242
+ isMinNotNone ? minFloat : std::numeric_limits<double >::lowest ());
5243
+ maxFloatAttr = rewriter.getF64FloatAttr (
5244
+ isMaxNotNone ? maxFloat : std::numeric_limits<double >::max ());
5245
+ } else {
5246
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
5247
+ }
5193
5248
5194
5249
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
5195
5250
op, outType, adaptor.getSelf (), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8602,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8547
8602
8548
8603
auto zi = self;
8549
8604
8605
+ FloatAttr minFloatAttr, maxFloatAttr;
8606
+ if (resultElemTy.isF16 ()) {
8607
+ minFloatAttr = rewriter.getF16FloatAttr (eps);
8608
+ maxFloatAttr = rewriter.getF16FloatAttr (1 - eps);
8609
+ } else if (resultElemTy.isBF16 ()) {
8610
+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), eps);
8611
+ maxFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 1 - eps);
8612
+ } else if (resultElemTy.isF32 ()) {
8613
+ minFloatAttr = rewriter.getF32FloatAttr (eps);
8614
+ maxFloatAttr = rewriter.getF32FloatAttr (1 - eps);
8615
+ } else if (resultElemTy.isF64 ()) {
8616
+ minFloatAttr = rewriter.getF64FloatAttr (eps);
8617
+ maxFloatAttr = rewriter.getF64FloatAttr (1 - eps);
8618
+ } else {
8619
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
8620
+ }
8621
+
8550
8622
// Clamp input to [eps, 1 - eps] when eps is not None
8551
8623
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
8552
8624
if (!isEpsNone) {
8553
8625
zi = rewriter
8554
8626
.create <tosa::ClampOp>(
8555
- op->getLoc (), resultType, self,
8556
- rewriter.getF32FloatAttr (static_cast <float >(eps)),
8557
- rewriter.getF32FloatAttr (static_cast <float >(1 - eps)),
8627
+ op->getLoc (), resultType, self, minFloatAttr, maxFloatAttr,
8558
8628
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ))
8559
8629
.getResult ();
8560
8630
}
0 commit comments