@@ -871,8 +871,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
871871 ConversionPatternRewriter &rewriter) const {
872872 Value self = adaptor.getSelf ();
873873 auto selfTy = cast<TensorType>(self.getType ());
874+ auto outTy =
875+ dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
876+ auto outElemTy = outTy.getElementType ();
874877
875- if (!selfTy) {
878+ if (!selfTy || !outTy ) {
876879 return rewriter.notifyMatchFailure (op,
877880 " Only Tensor types supported in TOSA" );
878881 }
@@ -883,12 +886,27 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
883886 op, " Only floating-point datatype legalization currently supported" );
884887 }
885888
889+ FloatAttr minFloatAttr, maxFloatAttr;
890+ if (outElemTy.isF16 ()) {
891+ minFloatAttr = rewriter.getF16FloatAttr (0 .0f );
892+ maxFloatAttr = rewriter.getF16FloatAttr (Float16Max);
893+ } else if (outElemTy.isBF16 ()) {
894+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 0 .0f );
895+ maxFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), BFloat16Max);
896+ } else if (outElemTy.isF32 ()) {
897+ minFloatAttr = rewriter.getF32FloatAttr (0 .0f );
898+ maxFloatAttr = rewriter.getF32FloatAttr (std::numeric_limits<float >::max ());
899+ } else if (outElemTy.isF64 ()) {
900+ minFloatAttr = rewriter.getF64FloatAttr (0 .0f );
901+ maxFloatAttr = rewriter.getF64FloatAttr (std::numeric_limits<double >::max ());
902+ } else {
903+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
904+ }
905+
886906 // Maps to tosa.clamp
887907 // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
888908 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
889- op, getTypeConverter ()->convertType (op.getType ()), self,
890- rewriter.getF32FloatAttr (0 .0f ),
891- rewriter.getF32FloatAttr (std::numeric_limits<float >::max ()),
909+ op, outTy, self, minFloatAttr, maxFloatAttr,
892910 /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
893911 return success ();
894912}
@@ -5186,10 +5204,30 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
51865204 op, outType, adaptor.getSelf (), minIntAttr, maxIntAttr,
51875205 /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ));
51885206 } else {
5189- FloatAttr minFloatAttr = rewriter.getF32FloatAttr (
5190- isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5191- FloatAttr maxFloatAttr = rewriter.getF32FloatAttr (
5192- isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5207+ FloatAttr minFloatAttr, maxFloatAttr;
5208+ if (outElemTy.isF16 ()) {
5209+ minFloatAttr =
5210+ rewriter.getF16FloatAttr (isMinNotNone ? minFloat : Float16Lowest);
5211+ maxFloatAttr =
5212+ rewriter.getF16FloatAttr (isMaxNotNone ? maxFloat : Float16Max);
5213+ } else if (outElemTy.isBF16 ()) {
5214+ minFloatAttr = rewriter.getFloatAttr (
5215+ rewriter.getBF16Type (), isMinNotNone ? minFloat : BFloat16Lowest);
5216+ maxFloatAttr = rewriter.getFloatAttr (
5217+ rewriter.getBF16Type (), isMaxNotNone ? maxFloat : BFloat16Max);
5218+ } else if (outElemTy.isF32 ()) {
5219+ minFloatAttr = rewriter.getF32FloatAttr (
5220+ isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5221+ maxFloatAttr = rewriter.getF32FloatAttr (
5222+ isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5223+ } else if (outElemTy.isF64 ()) {
5224+ minFloatAttr = rewriter.getF64FloatAttr (
5225+ isMinNotNone ? minFloat : std::numeric_limits<double >::lowest ());
5226+ maxFloatAttr = rewriter.getF64FloatAttr (
5227+ isMaxNotNone ? maxFloat : std::numeric_limits<double >::max ());
5228+ } else {
5229+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
5230+ }
51935231
51945232 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
51955233 op, outType, adaptor.getSelf (), minFloatAttr, maxFloatAttr,
@@ -8547,14 +8585,29 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
85478585
85488586 auto zi = self;
85498587
8588+ FloatAttr minFloatAttr, maxFloatAttr;
8589+ if (resultElemTy.isF16 ()) {
8590+ minFloatAttr = rewriter.getF16FloatAttr (eps);
8591+ maxFloatAttr = rewriter.getF16FloatAttr (1 - eps);
8592+ } else if (resultElemTy.isBF16 ()) {
8593+ minFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), eps);
8594+ maxFloatAttr = rewriter.getFloatAttr (rewriter.getBF16Type (), 1 - eps);
8595+ } else if (resultElemTy.isF32 ()) {
8596+ minFloatAttr = rewriter.getF32FloatAttr (eps);
8597+ maxFloatAttr = rewriter.getF32FloatAttr (1 - eps);
8598+ } else if (resultElemTy.isF64 ()) {
8599+ minFloatAttr = rewriter.getF64FloatAttr (eps);
8600+ maxFloatAttr = rewriter.getF64FloatAttr (1 - eps);
8601+ } else {
8602+ return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
8603+ }
8604+
85508605 // Clamp input to [eps, 1 - eps] when eps is not None
85518606 // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
85528607 if (!isEpsNone) {
85538608 zi = rewriter
85548609 .create <tosa::ClampOp>(
8555- op->getLoc (), resultType, self,
8556- rewriter.getF32FloatAttr (static_cast <float >(eps)),
8557- rewriter.getF32FloatAttr (static_cast <float >(1 - eps)),
8610+ op->getLoc (), resultType, self, minFloatAttr, maxFloatAttr,
85588611 /* nan_mode=*/ rewriter.getStringAttr (" PROPAGATE" ))
85598612 .getResult ();
85608613 }
0 commit comments