@@ -526,40 +526,111 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
-    AtenSqueezeOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+template <typename AtenOpT>
+class ConvertAtenSqueezeOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
 
-  Value self = adaptor.self();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
+  // Each variant must implement corresponding parameter parsing options
+  virtual LogicalResult
+  generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
+                        ConversionPatternRewriter &rewriter,
+                        SmallVector<int64_t> &squeezedShape) const {
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented dim/dim-list parsing function");
+  }
 
-  if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in TOSA argmax");
+  // Common rewriter for all squeeze ops; calls the specific implementation of
+  // generateSqueezedShape() needed for the op variant.
+  LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    Value self = adaptor.self();
+    auto selfTy = self.getType().template cast<RankedTensorType>();
+
+    if (!selfTy)
+      return op.emitError("Only ranked tensor types supported in TOSA squeeze");
+
+    SmallVector<int64_t> newOutputShape;
+    if (failed(generateSqueezedShape(op, selfTy, rewriter, newOutputShape)))
+      return op.emitError("Squeeze could not compute new shape");
+
+    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(op.getResult().getType())
+                        .template cast<RankedTensorType>();
+    auto resultElemTy = resultTy.getElementType();
 
-  auto selfShape = selfTy.getShape();
+    auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
 
-  SmallVector<int64_t> newOutputShape;
-  for (auto &dim : selfShape) {
-    if (dim != 1)
-      newOutputShape.push_back(dim);
+    auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            newOutputTy),
+        self, rewriter.getI64ArrayAttr(newOutputShape));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            newOutputTy),
+        reshapeOp);
+
+    return success();
   }
+};
 
-  auto resultTy = getTypeConverter()
-                      ->convertType(op.getResult().getType())
-                      .cast<RankedTensorType>();
-  auto resultElemTy = resultTy.getElementType();
+template <typename AtenOpT>
+class ConvertAtenSqueezeOneDimOp : public ConvertAtenSqueezeOp<AtenOpT> {
+  using ConvertAtenSqueezeOp<AtenOpT>::ConvertAtenSqueezeOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  LogicalResult
+  generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
+                        ConversionPatternRewriter &rewriter,
+                        SmallVector<int64_t> &squeezedShape) const {
+    int64_t squeezeDim;
+    if (!matchPattern(op.dim(), m_TorchConstantInt(&squeezeDim)))
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const dim parameter unsupported");
 
-  auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
+    // Handle negative dim
+    if (squeezeDim < 0)
+      squeezeDim = squeezeDim + selfTy.getRank();
 
-  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
-      rewriter.getI64ArrayAttr(newOutputShape));
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
+    auto selfShape = selfTy.getShape();
 
-  return success();
-}
+    // Only dims statically known to have size=1 are reduced. Dynamic dims
+    // are treated as unknowns and will not be squeezed, even if the dim
+    // parameter says they should be.
+    uint32_t dimNum = 0;
+    for (auto &dim : selfShape) {
+      if (dim != 1 || squeezeDim != dimNum)
+        squeezedShape.push_back(dim);
+      dimNum++;
+    }
+
+    return success();
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
+  using ConvertAtenSqueezeOp<AtenOpT>::ConvertAtenSqueezeOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  LogicalResult
+  generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
+                        ConversionPatternRewriter &rewriter,
+                        SmallVector<int64_t> &squeezedShape) const {
+    auto selfShape = selfTy.getShape();
+
+    // Dims that may dynamically resolve to 1 are not reduced here; only
+    // dims statically known to be 1 at compile time are handled.
+    for (auto &dim : selfShape) {
+      if (dim != 1)
+        squeezedShape.push_back(dim);
+    }
+    return success();
+  }
+};
 
 } // namespace
 
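The two generateSqueezedShape() variants above reduce to straightforward shape filtering. Below is a minimal standalone sketch of that logic with a few worked cases; the function names and the main() driver are illustrative only (not part of the patch) and carry no MLIR dependencies:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors ConvertAtenSqueezeAllDimsOp: drop every dim statically known to
// be 1. Dynamic dims never compare equal to 1 in the pattern, so they are
// left untouched.
std::vector<int64_t> squeezeAllDims(const std::vector<int64_t> &shape) {
  std::vector<int64_t> out;
  for (int64_t dim : shape)
    if (dim != 1)
      out.push_back(dim);
  return out;
}

// Mirrors ConvertAtenSqueezeOneDimOp: normalize a negative dim, then drop
// the selected dim only when it is statically 1.
std::vector<int64_t> squeezeOneDim(const std::vector<int64_t> &shape,
                                   int64_t squeezeDim) {
  if (squeezeDim < 0)
    squeezeDim += static_cast<int64_t>(shape.size());
  std::vector<int64_t> out;
  for (int64_t i = 0, e = shape.size(); i < e; ++i)
    if (shape[i] != 1 || i != squeezeDim)
      out.push_back(shape[i]);
  return out;
}

int main() {
  // [1, 3, 1, 4] squeezed on all dims -> [3, 4]
  assert((squeezeAllDims({1, 3, 1, 4}) == std::vector<int64_t>{3, 4}));
  // dim = -2 normalizes to 2, which is statically 1 -> [1, 3, 4]
  assert((squeezeOneDim({1, 3, 1, 4}, -2) == std::vector<int64_t>{1, 3, 4}));
  // dim = 1 has size 3, so nothing is removed
  assert((squeezeOneDim({1, 3, 1, 4}, 1) == std::vector<int64_t>{1, 3, 1, 4}));
  std::cout << "all checks passed\n";
}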
@@ -606,8 +677,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
 #undef INSERT_UNARY_PATTERN
 
-#define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                 \
-  target.addIllegalOp<AtenOp>();                                              \
+#define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                  \
+  target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
     INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
@@ -650,6 +721,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
                                         mlir::tosa::convertReduceSumOp)
 #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
 
+#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm)                        \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
+    INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
+    INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
+#undef INSERT_SQUEEZE_OP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
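For reference, each INSERT_SQUEEZE_OP_PATTERN invocation above expands to the usual outlaw-and-register pair; for the all-dims variant:

target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeAllDimsOp<AtenSqueezeOp>>(typeConverter, context);

The extra TemplateForm parameter is what lets AtenSqueezeOp and AtenSqueezeDimOp share the registration boilerplate while binding to different pattern classes.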
@@ -659,7 +737,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenMulTensorOp);
     INSERT_ATENOP_PATTERN(AtenDivTensorOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenSqueezeOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,