@@ -526,6 +526,41 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
526526 return success ();
527527}
528528
529+ template <>
530+ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
531+ AtenSqueezeOp op, OpAdaptor adaptor,
532+ ConversionPatternRewriter &rewriter) const {
533+
534+ Value self = adaptor.self ();
535+ auto selfTy = self.getType ().template cast <RankedTensorType>();
536+
537+ if (!selfTy)
538+ return op.emitError (" Only ranked tensor types supported in TOSA argmax" );
539+
540+ auto selfShape = selfTy.getShape ();
541+
542+ SmallVector<int64_t > newOutputShape;
543+ for (auto &dim : selfShape) {
544+ if (dim != 1 )
545+ newOutputShape.push_back (dim);
546+ }
547+
548+ auto resultTy = getTypeConverter ()
549+ ->convertType (op.getResult ().getType ())
550+ .cast <RankedTensorType>();
551+ auto resultElemTy = resultTy.getElementType ();
552+
553+ auto newOutputTy = RankedTensorType::get (newOutputShape, resultElemTy);
554+
555+ auto reshapeOp = rewriter.create <tosa::ReshapeOp>(
556+ op->getLoc (), getTypeConverter ()->convertType (newOutputTy), self,
557+ rewriter.getI64ArrayAttr (newOutputShape));
558+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
559+ op, getTypeConverter ()->convertType (newOutputTy), reshapeOp);
560+
561+ return success ();
562+ }
563+
529564} // namespace
530565
531566// -----------------------------------------------------------------------------
@@ -624,6 +659,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
624659 INSERT_ATENOP_PATTERN (AtenMulTensorOp);
625660 INSERT_ATENOP_PATTERN (AtenDivTensorOp);
626661 INSERT_ATENOP_PATTERN (AtenArgmaxOp);
662+ INSERT_ATENOP_PATTERN (AtenSqueezeOp);
627663#undef INSERT_ATENOP_PATTERN
628664
629665 if (failed (applyPartialConversion (getOperation (), target,
0 commit comments