Skip to content

Commit 0cd95b5

Browse files
authored
[tosa] Support for Torch.squeeze (#487)
1 parent 396ab35 commit 0cd95b5

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,7 @@
3838
"BoolTensorReturnTrueModule_basic",
3939
"BoolTensorReturnMixedModule_basic",
4040
"ElementwiseRsqrtModule_basic",
41+
"SqueezeModule_static",
42+
"SqueezeModule_noUnitDim",
43+
"SqueezeModule_allUnitDim",
4144
}

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,41 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
526526
return success();
527527
}
528528

529+
template <>
530+
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
531+
AtenSqueezeOp op, OpAdaptor adaptor,
532+
ConversionPatternRewriter &rewriter) const {
533+
534+
Value self = adaptor.self();
535+
auto selfTy = self.getType().template cast<RankedTensorType>();
536+
537+
if (!selfTy)
538+
return op.emitError("Only ranked tensor types supported in TOSA argmax");
539+
540+
auto selfShape = selfTy.getShape();
541+
542+
SmallVector<int64_t> newOutputShape;
543+
for (auto &dim : selfShape) {
544+
if (dim != 1)
545+
newOutputShape.push_back(dim);
546+
}
547+
548+
auto resultTy = getTypeConverter()
549+
->convertType(op.getResult().getType())
550+
.cast<RankedTensorType>();
551+
auto resultElemTy = resultTy.getElementType();
552+
553+
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
554+
555+
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
556+
op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
557+
rewriter.getI64ArrayAttr(newOutputShape));
558+
rewriter.replaceOpWithNewOp<tensor::CastOp>(
559+
op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
560+
561+
return success();
562+
}
563+
529564
} // namespace
530565

531566
// -----------------------------------------------------------------------------
@@ -624,6 +659,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
624659
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
625660
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
626661
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
662+
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
627663
#undef INSERT_ATENOP_PATTERN
628664

629665
if (failed(applyPartialConversion(getOperation(), target,

0 commit comments

Comments
 (0)