@@ -870,25 +870,21 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
870
870
Value self = adaptor.getSelf ();
871
871
auto selfTy = cast<TensorType>(self.getType ());
872
872
873
- // Maps to tosa.clamp which has both int and fp limits.
874
- int64_t clampMin = 0 ;
875
- Value clampIn = self;
876
873
if (!selfTy) {
877
874
return rewriter.notifyMatchFailure (op,
878
875
" Only Tensor types supported in TOSA" );
879
876
}
880
877
881
- // Rescale the clampIn for quantized types. TBD
878
+ // Rescale self for quantized types. TBD
882
879
if (!isa<mlir::FloatType>(selfTy.getElementType ())) {
883
880
return rewriter.notifyMatchFailure (
884
881
op, " Only floating-point datatype legalization currently supported" );
885
882
}
886
883
884
+ // Maps to tosa.clamp
887
885
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888
886
rewriter.replaceOpWithNewOp <tosa::ClampOp>(
889
- op, getTypeConverter ()->convertType (op.getType ()), clampIn,
890
- rewriter.getI64IntegerAttr (clampMin),
891
- rewriter.getI64IntegerAttr (std::numeric_limits<int32_t >::max ()),
887
+ op, getTypeConverter ()->convertType (op.getType ()), self,
892
888
rewriter.getF32FloatAttr (0 .0f ),
893
889
rewriter.getF32FloatAttr (std::numeric_limits<float >::max ()),
894
890
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
@@ -5120,53 +5116,88 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
5120
5116
return rewriter.notifyMatchFailure (
5121
5117
op, " only tensor types input are currently supported" );
5122
5118
5123
- IntegerAttr min_int =
5124
- rewriter.getI64IntegerAttr (std::numeric_limits<int64_t >::min ());
5125
- IntegerAttr max_int =
5126
- rewriter.getI64IntegerAttr (std::numeric_limits<int64_t >::max ());
5127
- FloatAttr min_fp =
5128
- rewriter.getF32FloatAttr (std::numeric_limits<float >::lowest ());
5129
- FloatAttr max_fp =
5130
- rewriter.getF32FloatAttr (std::numeric_limits<float >::max ());
5131
-
5132
- auto getValAttr = [&](Value operand, IntegerAttr &intAttr,
5133
- FloatAttr &fpAttr) -> LogicalResult {
5134
- double valFloat;
5135
- int64_t valInt;
5136
- if (matchPattern (operand, m_TorchConstantFloat (&valFloat))) {
5137
- intAttr = rewriter.getI64IntegerAttr (static_cast <int64_t >(valFloat));
5138
- fpAttr = rewriter.getF32FloatAttr (static_cast <float >(valFloat));
5139
- } else if (matchPattern (operand, m_TorchConstantInt (&valInt))) {
5140
- intAttr = rewriter.getI64IntegerAttr (valInt);
5141
- fpAttr = rewriter.getF32FloatAttr (static_cast <float >(valInt));
5119
+ auto outType =
5120
+ dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
5121
+ auto outElemTy = outType.getElementType ();
5122
+
5123
+ int64_t minInt, maxInt;
5124
+ double minFloat, maxFloat;
5125
+ bool isMinNotNone = false ;
5126
+ bool isMaxNotNone = false ;
5127
+
5128
+ auto isMinInt = matchPattern (op.getMin (), m_TorchConstantInt (&minInt));
5129
+ auto isMinFloat = matchPattern (op.getMin (), m_TorchConstantFloat (&minFloat));
5130
+ if (isMinInt) {
5131
+ minFloat = static_cast <float >(minInt);
5132
+ isMinNotNone = true ;
5133
+ } else if (isMinFloat) {
5134
+ minInt = static_cast <int64_t >(minFloat);
5135
+ isMinNotNone = true ;
5136
+ } else {
5137
+ if (succeeded (checkNotNone (rewriter, op, op.getMin ())))
5138
+ return rewriter.notifyMatchFailure (op,
5139
+ " min attr should be a torch constant" );
5140
+ }
5141
+
5142
+ auto isMaxInt = matchPattern (op.getMax (), m_TorchConstantInt (&maxInt));
5143
+ auto isMaxFloat = matchPattern (op.getMax (), m_TorchConstantFloat (&maxFloat));
5144
+ if (isMaxInt) {
5145
+ maxFloat = static_cast <float >(maxInt);
5146
+ isMaxNotNone = true ;
5147
+ } else if (isMaxFloat) {
5148
+ maxInt = static_cast <int64_t >(maxFloat);
5149
+ isMaxNotNone = true ;
5150
+ } else {
5151
+ if (succeeded (checkNotNone (rewriter, op, op.getMax ())))
5152
+ return rewriter.notifyMatchFailure (op,
5153
+ " max attr should be a torch constant" );
5154
+ }
5155
+
5156
+ if (!isa<mlir::FloatType>(outElemTy)) {
5157
+ IntegerAttr minIntAttr, maxIntAttr;
5158
+ if (outElemTy.isInteger (8 )) {
5159
+ minIntAttr = rewriter.getIntegerAttr (
5160
+ outElemTy,
5161
+ isMinNotNone ? minInt : std::numeric_limits<int8_t >::min ());
5162
+ maxIntAttr = rewriter.getIntegerAttr (
5163
+ outElemTy,
5164
+ isMaxNotNone ? maxInt : std::numeric_limits<int8_t >::max ());
5165
+ } else if (outElemTy.isInteger (16 )) {
5166
+ minIntAttr = rewriter.getIntegerAttr (
5167
+ outElemTy,
5168
+ isMinNotNone ? minInt : std::numeric_limits<int16_t >::min ());
5169
+ maxIntAttr = rewriter.getIntegerAttr (
5170
+ outElemTy,
5171
+ isMaxNotNone ? maxInt : std::numeric_limits<int16_t >::max ());
5172
+ } else if (outElemTy.isInteger (32 )) {
5173
+ minIntAttr = rewriter.getIntegerAttr (
5174
+ outElemTy,
5175
+ isMinNotNone ? minInt : std::numeric_limits<int32_t >::min ());
5176
+ maxIntAttr = rewriter.getIntegerAttr (
5177
+ outElemTy,
5178
+ isMaxNotNone ? maxInt : std::numeric_limits<int32_t >::max ());
5179
+ } else if (outElemTy.isInteger (64 )) {
5180
+ minIntAttr = rewriter.getI64IntegerAttr (
5181
+ isMinNotNone ? minInt : std::numeric_limits<int64_t >::min ());
5182
+ maxIntAttr = rewriter.getI64IntegerAttr (
5183
+ isMaxNotNone ? maxInt : std::numeric_limits<int64_t >::max ());
5142
5184
} else {
5143
- return failure ( );
5185
+ return rewriter. notifyMatchFailure (op, " Unsupported integer type " );
5144
5186
}
5145
- return success ();
5146
- };
5147
5187
5148
- LogicalResult minAttrResult = getValAttr (op.getMin (), min_int, min_fp);
5149
- LogicalResult maxAttrResult = getValAttr (op.getMax (), max_int, max_fp);
5150
- if (failed (minAttrResult) && failed (maxAttrResult)) {
5151
- return rewriter.notifyMatchFailure (
5152
- op, " either `min` or `max` should be a torch constant" );
5153
- }
5154
- if (failed (minAttrResult) &&
5155
- succeeded (checkNotNone (rewriter, op, op.getMin ()))) {
5156
- return rewriter.notifyMatchFailure (op,
5157
- " min attr should be a torch constant" );
5158
- }
5159
- if (failed (maxAttrResult) &&
5160
- succeeded (checkNotNone (rewriter, op, op.getMax ()))) {
5161
- return rewriter.notifyMatchFailure (op,
5162
- " max attr should be a torch constant" );
5163
- }
5188
+ rewriter.replaceOpWithNewOp <tosa::ClampOp>(
5189
+ op, outType, adaptor.getSelf (), minIntAttr, maxIntAttr,
5190
+ /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
5191
+ } else {
5192
+ FloatAttr minFloatAttr = rewriter.getF32FloatAttr (
5193
+ isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5194
+ FloatAttr maxFloatAttr = rewriter.getF32FloatAttr (
5195
+ isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5164
5196
5165
- // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
5166
- auto outType = getTypeConverter ()->convertType (op.getType ());
5167
- rewriter.replaceOpWithNewOp <tosa::ClampOp>(
5168
- op, outType, adaptor.getSelf (), min_int, max_int, min_fp, max_fp,
5169
- /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
5197
+ rewriter.replaceOpWithNewOp <tosa::ClampOp>(
5198
+ op, outType, adaptor.getSelf (), minFloatAttr, maxFloatAttr,
5199
+ /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
5200
+ }
5170
5201
5171
5202
return success ();
5172
5203
}
@@ -6788,12 +6819,12 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
6788
6819
border_y);
6789
6820
normalize (inputWidth, outputWidth, scale_x_n, scale_x_d, offset_x, border_x);
6790
6821
6791
- DenseI64ArrayAttr scale = rewriter. getDenseI64ArrayAttr (
6792
- {scale_y_n, scale_y_d, scale_x_n, scale_x_d});
6793
- DenseI64ArrayAttr offset =
6794
- rewriter. getDenseI64ArrayAttr ( {offset_y, offset_x});
6795
- DenseI64ArrayAttr border =
6796
- rewriter. getDenseI64ArrayAttr ( {border_y, border_x});
6822
+ auto scale = tosa::getTosaConstShape (
6823
+ rewriter, op-> getLoc (), {scale_y_n, scale_y_d, scale_x_n, scale_x_d});
6824
+ auto offset =
6825
+ tosa::getTosaConstShape ( rewriter, op-> getLoc (), {offset_y, offset_x});
6826
+ auto border =
6827
+ tosa::getTosaConstShape ( rewriter, op-> getLoc (), {border_y, border_x});
6797
6828
StringAttr modeAttr = rewriter.getStringAttr (mode);
6798
6829
6799
6830
auto resizeOpResult =
@@ -8486,8 +8517,6 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8486
8517
zi = rewriter
8487
8518
.create <tosa::ClampOp>(
8488
8519
op->getLoc (), resultType, self,
8489
- rewriter.getI64IntegerAttr (static_cast <int64_t >(eps)),
8490
- rewriter.getI64IntegerAttr (static_cast <int64_t >(1 - eps)),
8491
8520
rewriter.getF32FloatAttr (static_cast <float >(eps)),
8492
8521
rewriter.getF32FloatAttr (static_cast <float >(1 - eps)),
8493
8522
/* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ))
0 commit comments