Skip to content

Commit b4842d9

Browse files
authored
[tosa] Implement squeeze.dim support (#511)
Templated variants for squeeze and squeeze.dim
1 parent 3c40539 commit b4842d9

File tree

2 files changed

+108
-28
lines changed

2 files changed

+108
-28
lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,7 @@
4545
"TModuleRank0_basic",
4646
"ElementwiseToDtypeIdentityModule_basic",
4747
"View1DFoldModule_basic",
48+
"SqueezeDimModule_static",
49+
"SqueezeDimModule_identity",
50+
"SqueezeDimModule_unitDim",
4851
}

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 105 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -526,40 +526,111 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
526526
return success();
527527
}
528528

529-
template <>
530-
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
531-
AtenSqueezeOp op, OpAdaptor adaptor,
532-
ConversionPatternRewriter &rewriter) const {
529+
template <typename AtenOpT>
530+
class ConvertAtenSqueezeOp : public OpConversionPattern<AtenOpT> {
531+
public:
532+
using OpConversionPattern<AtenOpT>::OpConversionPattern;
533+
using OpAdaptor = typename AtenOpT::Adaptor;
533534

534-
Value self = adaptor.self();
535-
auto selfTy = self.getType().template cast<RankedTensorType>();
535+
// Each variant must implement corresponding parameter parsing options
536+
virtual LogicalResult
537+
generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
538+
ConversionPatternRewriter &rewriter,
539+
SmallVector<int64_t> &squeezedShape) const {
540+
return rewriter.notifyMatchFailure(
541+
op, "Unimplemented dim/dim-list parsing function");
542+
}
536543

537-
if (!selfTy)
538-
return op.emitError("Only ranked tensor types supported in TOSA argmax");
544+
// Common rewriter for all squeeze ops, calls the specific implementation of
545+
// generateSqueezedShape() needed for the op variant.
546+
LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
547+
ConversionPatternRewriter &rewriter) const {
548+
Value self = adaptor.self();
549+
auto selfTy = self.getType().template cast<RankedTensorType>();
550+
551+
if (!selfTy)
552+
return op.emitError("Only ranked tensor types supported in TOSA argmax");
553+
554+
SmallVector<int64_t> newOutputShape;
555+
if (failed(generateSqueezedShape(op, selfTy, rewriter, newOutputShape)))
556+
return op.emitError("Squeeze could not compute new shape");
557+
558+
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
559+
->convertType(op.getResult().getType())
560+
.template cast<RankedTensorType>();
561+
auto resultElemTy = resultTy.getElementType();
539562

540-
auto selfShape = selfTy.getShape();
563+
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
541564

542-
SmallVector<int64_t> newOutputShape;
543-
for (auto &dim : selfShape) {
544-
if (dim != 1)
545-
newOutputShape.push_back(dim);
565+
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
566+
op->getLoc(),
567+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
568+
newOutputTy),
569+
self, rewriter.getI64ArrayAttr(newOutputShape));
570+
rewriter.replaceOpWithNewOp<tensor::CastOp>(
571+
op,
572+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
573+
newOutputTy),
574+
reshapeOp);
575+
576+
return success();
546577
}
578+
};
547579

548-
auto resultTy = getTypeConverter()
549-
->convertType(op.getResult().getType())
550-
.cast<RankedTensorType>();
551-
auto resultElemTy = resultTy.getElementType();
580+
template <typename AtenOpT>
581+
class ConvertAtenSqueezeOneDimOp : public ConvertAtenSqueezeOp<AtenOpT> {
582+
using ConvertAtenSqueezeOp<AtenOpT>::ConvertAtenSqueezeOp;
583+
using OpAdaptor = typename AtenOpT::Adaptor;
584+
585+
LogicalResult
586+
generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
587+
ConversionPatternRewriter &rewriter,
588+
SmallVector<int64_t> &squeezedShape) const {
589+
int64_t squeezeDim;
590+
if (!matchPattern(op.dim(), m_TorchConstantInt(&squeezeDim)))
591+
return rewriter.notifyMatchFailure(op,
592+
"non-const dim parameter unsupported");
552593

553-
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
594+
// Handle negative dim
595+
if (squeezeDim < 0)
596+
squeezeDim = squeezeDim + selfTy.getRank();
554597

555-
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
556-
op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
557-
rewriter.getI64ArrayAttr(newOutputShape));
558-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
559-
op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
598+
auto selfShape = selfTy.getShape();
560599

561-
return success();
562-
}
600+
// Only dims statically known to have size=1 are reduced.
601+
// Dynamic dims are treated as unknowns and will not be squeezed
602+
// even if dim parameter says it should be.
603+
uint32_t dimNum = 0;
604+
for (auto &dim : selfShape) {
605+
if (dim != 1 || squeezeDim != dimNum)
606+
squeezedShape.push_back(dim);
607+
dimNum++;
608+
}
609+
610+
return success();
611+
}
612+
};
613+
614+
template <typename AtenOpT>
615+
class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
616+
using ConvertAtenSqueezeOp<AtenOpT>::ConvertAtenSqueezeOp;
617+
using OpAdaptor = typename AtenOpT::Adaptor;
618+
619+
LogicalResult
620+
generateSqueezedShape(AtenOpT op, RankedTensorType selfTy,
621+
ConversionPatternRewriter &rewriter,
622+
SmallVector<int64_t> &squeezedShape) const {
623+
auto selfShape = selfTy.getShape();
624+
625+
// Dims that may dynamically resolve to 1 are not reduced here. Only
626+
// compile-time resolvable dims are handled here.
627+
for (auto &dim : selfShape) {
628+
if (dim != 1)
629+
squeezedShape.push_back(dim);
630+
}
631+
return success();
632+
}
633+
};
563634

564635
} // namespace
565636

@@ -606,8 +677,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
606677
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
607678
#undef INSERT_UNARY_PATTERN
608679

609-
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
610-
target.addIllegalOp<AtenOp>(); \
680+
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
681+
target.addIllegalOp<AtenOp>(); \
611682
patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
612683
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
613684
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
@@ -650,6 +721,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
650721
mlir::tosa::convertReduceSumOp)
651722
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
652723

724+
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
725+
target.addIllegalOp<AtenOp>(); \
726+
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
727+
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp)
728+
INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp)
729+
#undef INSERT_SQUEEZE_OP_PATTERN
730+
653731
#define INSERT_ATENOP_PATTERN(AtenOp) \
654732
target.addIllegalOp<AtenOp>(); \
655733
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
@@ -659,7 +737,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
659737
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
660738
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
661739
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
662-
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
663740
#undef INSERT_ATENOP_PATTERN
664741

665742
if (failed(applyPartialConversion(getOperation(), target,

0 commit comments

Comments
 (0)