@@ -12580,6 +12580,198 @@ class DecomposeAtenRoundDecimalsOp
};
} // namespace
+ namespace {
+ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+ public:
+ using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(AtenAsStridedOp op,
+ PatternRewriter &rewriter) const override {
+
+ // The `aten.as_strided` operation is decomposed into a series of
+ // operations that compute the indices based on the provided sizes and
+ // strides, and then index into the flattened input tensor as follows:
+
+ // input_flat = input.view(-1)
+ //
+ // for dim, s in enumerate(self.size):
+ //   arange = torch.arange(s)
+ //   view_shape = []
+ //   for i in range(len(self.size)):
+ //     if i == dim:
+ //       view_shape.append(-1)
+ //     else:
+ //       view_shape.append(1)
+ //   arange = arange.view(view_shape)
+ //   if dim == 0:
+ //     idx = arange * self.stride[dim]
+ //   else:
+ //     idx = idx + arange * self.stride[dim]
+ //
+ // # Flatten indices and add offset
+ // final_indices = idx.reshape(-1) + self.storage_offset
+ //
+ // # Index the flattened input tensor
+ // output = input_flat[final_indices]
+ //
+ // # Reshape to desired output size
+ // return output.view(self.size)
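+ //
+ // Illustrative example (not taken from this patch): with size=[2, 2],
+ // stride=[3, 1], and storage_offset=1, the loop produces
+ // idx = [[0], [3]] + [[0, 1]] = [[0, 1], [3, 4]], so
+ // final_indices = [1, 2, 4, 5] and the result is
+ // input_flat[[1, 2, 4, 5]].view(2, 2).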
+
+ Location loc = op.getLoc();
+ MLIRContext *context = op->getContext();
+ Value input = op.getSelf();
+ auto inputType = dyn_cast<BaseTensorType>(input.getType());
+
+ if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
+ return rewriter.notifyMatchFailure(op, "input must have known sizes");
+
+ SmallVector<int64_t> sizesInts;
+ if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
+ return rewriter.notifyMatchFailure(
+ op, "sizes must be a list of constant ints");
+
+ SmallVector<int64_t> stridesInts;
+ if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
+ return rewriter.notifyMatchFailure(
+ op, "strides must be a list of constant ints");
+
+ int64_t storageOffset = 0;
+ if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
+ if (!matchPattern(op.getStorageOffset(),
+ m_TorchConstantInt(&storageOffset)))
+ return rewriter.notifyMatchFailure(
+ op, "storage_offset must be a constant integer");
+ }
+
+ ArrayRef<int64_t> inputSizes = inputType.getSizes();
+ int64_t inputRank = inputSizes.size();
+ int64_t resultRank = sizesInts.size();
+
+ Value cstZero =
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+ if (inputRank > 1) {
+ // If the input is not already 1-D, flatten it to a 1-D tensor before
+ // applying the strided indexing.
+ int64_t flattenedInputSize = 1;
+ for (int64_t size : inputSizes)
+ flattenedInputSize *= size;
+
+ auto flattenedInputTy =
+ cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+ {flattenedInputSize}, inputType.getOptionalDtype()));
+
+ Value end = rewriter.create<ConstantIntOp>(
+ loc, rewriter.getI64IntegerAttr(inputRank - 1));
+ input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
+ input, cstZero, end);
+ }
+
+ Value cstOne =
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+ Value cstMinusOne =
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+ SmallVector<int64_t> viewShapeInts(resultRank, 1);
+ SmallVector<Value> viewShapeListElems(resultRank, cstOne);
+
+ auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+ Value finalIndices;
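+ // Build one broadcastable index grid per output dimension and accumulate
+ // them; after the loop, finalIndices[i0, ..., iN-1] = sum_d(i_d * stride[d]).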
+ for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
+ int64_t size = sizesInts[dim];
+ Value cstNone = rewriter.create<ConstantNoneOp>(loc);
+ Value end =
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(size));
+
+ auto arangeType =
+ ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
+ Value index = rewriter.create<Torch::AtenArangeOp>(
+ loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
+
+ // Set the current dimension to -1 for broadcasting
+ viewShapeInts[dim] = -1;
+ viewShapeListElems[dim] = cstMinusOne;
+
+ Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
+ loc, Torch::ListType::get(Torch::IntType::get(context)),
+ viewShapeListElems);
+
+ auto viewType = ValueTensorType::get(
+ context, llvm::ArrayRef(viewShapeInts), si64Type);
+ index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);
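+ // After the view, `index` has shape [1, ..., size, ..., 1] with `size` at
+ // position `dim`, so it broadcasts against the grids of the other dims.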
+
+ // Multiply the index by the stride of the current dimension
+ Value cstStride = rewriter.create<ConstantIntOp>(
+ loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
+ index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
+
+ // Reset the current dimension to 1 for the next iteration
+ viewShapeInts[dim] = 1;
+ viewShapeListElems[dim] = cstOne;
+
+ if (dim == 0) {
+ finalIndices = index;
+ continue;
+ }
+
+ // Compute the common shape for broadcasting the accumulated and current indices
+ SmallVector<int64_t> broadcastShape;
+ SmallVector<Value> broadcastShapeValue;
+ computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
+ broadcastShapeValue);
+ Type broadcastType = ValueTensorType::get(
+ context, llvm::ArrayRef(broadcastShape), si64Type);
+
+ finalIndices = rewriter.create<AtenAddTensorOp>(
+ loc, broadcastType, finalIndices, index, cstOne);
+ }
+
+ int64_t flattenedResultSize = 1;
+ for (int64_t size : sizesInts)
+ flattenedResultSize *= size;
+
+ // Flatten the indices and add the storage offset
+ finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
+ loc,
+ ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+ si64Type),
+ finalIndices, cstZero, cstMinusOne); // -1 means flatten all
+
+ if (storageOffset != 0) {
+ Value cstStorageOffset = rewriter.create<ConstantIntOp>(
+ loc, rewriter.getI64IntegerAttr(storageOffset));
+ finalIndices = rewriter.create<AtenAddScalarOp>(
+ loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
+ }
+
+ // Index the flattened input tensor
+ Type listElemType =
+ inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+ /*optionalDtype=*/nullptr);
+ Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
+ loc, Torch::ListType::get(listElemType),
+ SmallVector<Value>{finalIndices});
+
+ auto flattenedResultTy =
+ ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+ inputType.getOptionalDtype());
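+ // With a single 1-D index tensor, aten.index.Tensor gathers
+ // result[i] = input[finalIndices[i]] from the flattened input.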
+ Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
+ input, indicesList);
+
+ // Reshape the result to the desired output size
+ SmallVector<Value> sizesIntsValues;
+ for (int64_t size : sizesInts) {
+ sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
+ loc, rewriter.getI64IntegerAttr(size)));
+ }
+ Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+ loc, Torch::ListType::get(Torch::IntType::get(context)),
+ sizesIntsValues);
+ result =
+ rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+ };
+ } // namespace
+
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12904,6 +13096,7 @@ class DecomposeComplexOpsPass
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
+ addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);

GreedyRewriteConfig config;
config.setUseTopDownTraversal(true);