@@ -10214,6 +10214,31 @@ class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
1021410214};
1021510215} // namespace
1021610216
10217+ namespace {
10218+ // decompose aten.argsort to aten.sort
10219+ class DecomposeAtenArgsortOp : public OpRewritePattern <AtenArgsortOp> {
10220+ public:
10221+ using OpRewritePattern::OpRewritePattern;
10222+ LogicalResult matchAndRewrite (AtenArgsortOp op,
10223+ PatternRewriter &rewriter) const override {
10224+ Location loc = op.getLoc ();
10225+ auto context = op.getContext ();
10226+
10227+ Value self = op.getSelf ();
10228+ Value dim = op.getDim ();
10229+ Value descending = op.getDescending ();
10230+ auto selfType = cast<BaseTensorType>(self.getType ());
10231+ auto sortIndicesType = selfType.getWithSizesAndDtype (
10232+ selfType.getOptionalSizes (),
10233+ IntegerType::get (context, 64 , IntegerType::Signed));
10234+ auto sortOpResult = rewriter.create <AtenSortOp>(
10235+ loc, self.getType (), sortIndicesType, self, dim, descending);
10236+ rewriter.replaceOp (op, sortOpResult->getResult (1 ));
10237+ return success ();
10238+ }
10239+ };
10240+ } // namespace
10241+
1021710242namespace {
1021810243
1021910244// / Creates coefficients based on DFT definition, see
@@ -11781,6 +11806,7 @@ class DecomposeComplexOpsPass
1178111806 patterns);
1178211807 addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
1178311808 addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
11809+ addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
1178411810 addPatternIfTargetOpIsIllegal<DecomposeAtenFftRfftOp>(patterns);
1178511811 addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
1178611812 addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
0 commit comments