@@ -871,8 +871,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
871
871
ConversionPatternRewriter &rewriter) const {
872
872
Value self = adaptor.getSelf ();
873
873
auto selfTy = cast<TensorType>(self.getType ());
874
+ auto outTy =
875
+ dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
876
+ auto outElemTy = outTy.getElementType ();
874
877
875
- if (!selfTy) {
878
+ if (!selfTy || !outTy ) {
876
879
return rewriter.notifyMatchFailure (op,
877
880
" Only Tensor types supported in TOSA" );
878
881
}
@@ -883,12 +886,27 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
883
886
op, " Only floating-point datatype legalization currently supported" );
884
887
}
885
888
889
+ FloatAttr minFloatAttr, maxFloatAttr;
890
+ if (outElemTy.isF16 ()) {
891
+ minFloatAttr = rewriter.getF16FloatAttr (0 .0f );
892
+ maxFloatAttr = rewriter.getF16FloatAttr (Float16Max);
893
+ } else if (outElemTy.isBF16 ()) {
894
+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 0 .0f );
895
+ maxFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), BFloat16Max);
896
+ } else if (outElemTy.isF32 ()) {
897
+ minFloatAttr = rewriter.getF32FloatAttr (0 .0f );
898
+ maxFloatAttr = rewriter.getF32FloatAttr (std::numeric_limits<float >::max ());
899
+ } else if (outElemTy.isF64 ()) {
900
+ minFloatAttr = rewriter.getF64FloatAttr (0 .0f );
901
+ maxFloatAttr = rewriter.getF64FloatAttr (std::numeric_limits<double >::max ());
902
+ } else {
903
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
904
+ }
905
+
886
906
// Maps to tosa.clamp
887
907
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888
908
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
889
- op, getTypeConverter ()->convertType (op.getType ()), self,
890
- rewriter.getF32FloatAttr (0 .0f ),
891
- rewriter.getF32FloatAttr (std::numeric_limits<float >::max ()),
909
+ op, outTy, self, minFloatAttr, maxFloatAttr,
892
910
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
893
911
return success ();
894
912
}
@@ -5186,10 +5204,30 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
5186
5204
op, outType, adaptor.getSelf (), minIntAttr, maxIntAttr,
5187
5205
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
5188
5206
} else {
5189
- FloatAttr minFloatAttr = rewriter.getF32FloatAttr (
5190
- isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5191
- FloatAttr maxFloatAttr = rewriter.getF32FloatAttr (
5192
- isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5207
+ FloatAttr minFloatAttr, maxFloatAttr;
5208
+ if (outElemTy.isF16 ()) {
5209
+ minFloatAttr =
5210
+ rewriter.getF16FloatAttr (isMinNotNone ? minFloat : Float16Lowest);
5211
+ maxFloatAttr =
5212
+ rewriter.getF16FloatAttr (isMaxNotNone ? maxFloat : Float16Max);
5213
+ } else if (outElemTy.isBF16 ()) {
5214
+ minFloatAttr = rewriter.getFloatAttr (
5215
+ rewriter.getBF16Type (), isMinNotNone ? minFloat : BFloat16Lowest);
5216
+ maxFloatAttr = rewriter.getFloatAttr (
5217
+ rewriter.getBF16Type (), isMaxNotNone ? maxFloat : BFloat16Max);
5218
+ } else if (outElemTy.isF32 ()) {
5219
+ minFloatAttr = rewriter.getF32FloatAttr (
5220
+ isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5221
+ maxFloatAttr = rewriter.getF32FloatAttr (
5222
+ isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5223
+ } else if (outElemTy.isF64 ()) {
5224
+ minFloatAttr = rewriter.getF64FloatAttr (
5225
+ isMinNotNone ? minFloat : std::numeric_limits<double >::lowest ());
5226
+ maxFloatAttr = rewriter.getF64FloatAttr (
5227
+ isMaxNotNone ? maxFloat : std::numeric_limits<double >::max ());
5228
+ } else {
5229
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
5230
+ }
5193
5231
5194
5232
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
5195
5233
op, outType, adaptor.getSelf (), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8585,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8547
8585
8548
8586
auto zi = self;
8549
8587
8588
+ FloatAttr minFloatAttr, maxFloatAttr;
8589
+ if (resultElemTy.isF16 ()) {
8590
+ minFloatAttr = rewriter.getF16FloatAttr (eps);
8591
+ maxFloatAttr = rewriter.getF16FloatAttr (1 - eps);
8592
+ } else if (resultElemTy.isBF16 ()) {
8593
+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), eps);
8594
+ maxFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 1 - eps);
8595
+ } else if (resultElemTy.isF32 ()) {
8596
+ minFloatAttr = rewriter.getF32FloatAttr (eps);
8597
+ maxFloatAttr = rewriter.getF32FloatAttr (1 - eps);
8598
+ } else if (resultElemTy.isF64 ()) {
8599
+ minFloatAttr = rewriter.getF64FloatAttr (eps);
8600
+ maxFloatAttr = rewriter.getF64FloatAttr (1 - eps);
8601
+ } else {
8602
+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
8603
+ }
8604
+
8550
8605
// Clamp input to [eps, 1 - eps] when eps is not None
8551
8606
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
8552
8607
if (!isEpsNone) {
8553
8608
zi = rewriter
8554
8609
.create <tosa::ClampOp>(
8555
- op->getLoc (), resultType, self,
8556
- rewriter.getF32FloatAttr (static_cast <float >(eps)),
8557
- rewriter.getF32FloatAttr (static_cast <float >(1 - eps)),
8610
+ op->getLoc (), resultType, self, minFloatAttr, maxFloatAttr,
8558
8611
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ))
8559
8612
.getResult ();
8560
8613
}
0 commit comments