Skip to content

Commit 155680c

Browse files
[MLIR][TORCH] Add E2E support for aten.as_strided op (#4269)
This commit adds the e2e support for the aten.as_strided op by decomposing it into a series of other torch operations. Fixes #4191. The failing tests for Tosa config are tracked by #4272. --------- Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent bc1dae9 commit 155680c

File tree

4 files changed

+254
-13
lines changed

4 files changed

+254
-13
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12580,6 +12580,198 @@ class DecomposeAtenRoundDecimalsOp
1258012580
};
1258112581
} // namespace
1258212582

12583+
namespace {
12584+
class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
12585+
public:
12586+
using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
12587+
LogicalResult matchAndRewrite(AtenAsStridedOp op,
12588+
PatternRewriter &rewriter) const override {
12589+
12590+
// The `aten.as_strided` operation is decomposed into a series of
12591+
// operations that compute the indices based on the provided sizes and
12592+
// strides, and then index into the flattened input tensor as follows:
12593+
12594+
// input_flat = input.view(-1)
12595+
//
12596+
// for dim, s in enumerate(self.size):
12597+
// arange = torch.arange(s)
12598+
// view_shape = []
12599+
// for i in range(len(self.size)):
12600+
// if i == dim:
12601+
// view_shape.append(-1)
12602+
// else:
12603+
// view_shape.append(1)
12604+
// arange = arange.view(view_shape)
12605+
// if dim != 0:
12606+
// idx = idx + arange * self.stride[dim]
12607+
//
12608+
// # Flatten indices and add offset
12609+
// final_indices = idx.reshape(-1) + self.storage_offset
12610+
//
12611+
// # Index the flattened input tensor
12612+
// output = input_flat[final_indices]
12613+
//
12614+
// # Reshape to desired output size
12615+
// return output.view(self.size)
12616+
12617+
Location loc = op.getLoc();
12618+
MLIRContext *context = op->getContext();
12619+
Value input = op.getSelf();
12620+
auto inputType = dyn_cast<BaseTensorType>(input.getType());
12621+
12622+
if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
12623+
return rewriter.notifyMatchFailure(op, "input must have known sizes");
12624+
12625+
SmallVector<int64_t> sizesInts;
12626+
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
12627+
return rewriter.notifyMatchFailure(
12628+
op, "sizes must be a list of constant ints");
12629+
12630+
SmallVector<int64_t> stridesInts;
12631+
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
12632+
return rewriter.notifyMatchFailure(
12633+
op, "strides must be a list of constant ints");
12634+
12635+
int64_t storageOffset = 0;
12636+
if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
12637+
if (!matchPattern(op.getStorageOffset(),
12638+
m_TorchConstantInt(&storageOffset)))
12639+
return rewriter.notifyMatchFailure(
12640+
op, "storage_offset must be a constant integer");
12641+
}
12642+
12643+
ArrayRef<int64_t> inputSizes = inputType.getSizes();
12644+
int64_t inputRank = inputSizes.size();
12645+
int64_t resultRank = sizesInts.size();
12646+
12647+
Value cstZero =
12648+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
12649+
if (inputRank > 1) {
12650+
// If the input is not a 1-d tensor, we need to flatten it
12651+
// to a 1D tensor before applying the strided indexing.
12652+
int64_t flattenedInputSize = 1;
12653+
for (int64_t size : inputSizes)
12654+
flattenedInputSize *= size;
12655+
12656+
auto flattenedInputTy =
12657+
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
12658+
{flattenedInputSize}, inputType.getOptionalDtype()));
12659+
12660+
Value end = rewriter.create<ConstantIntOp>(
12661+
loc, rewriter.getI64IntegerAttr(inputRank - 1));
12662+
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
12663+
input, cstZero, end);
12664+
}
12665+
12666+
Value cstOne =
12667+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
12668+
Value cstMinusOne =
12669+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
12670+
12671+
SmallVector<int64_t> viewShapeInts(resultRank, 1);
12672+
SmallVector<Value> viewShapeListElems(resultRank, cstOne);
12673+
12674+
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
12675+
Value finalIndices;
12676+
for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
12677+
int64_t size = sizesInts[dim];
12678+
Value cstNone = rewriter.create<ConstantNoneOp>(loc);
12679+
Value end =
12680+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(size));
12681+
12682+
auto arangeType =
12683+
ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
12684+
Value index = rewriter.create<Torch::AtenArangeOp>(
12685+
loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
12686+
12687+
// Set the current dimension to -1 for broadcasting
12688+
viewShapeInts[dim] = -1;
12689+
viewShapeListElems[dim] = cstMinusOne;
12690+
12691+
Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
12692+
loc, Torch::ListType::get(Torch::IntType::get(context)),
12693+
viewShapeListElems);
12694+
12695+
auto viewType = ValueTensorType::get(
12696+
context, llvm::ArrayRef(viewShapeInts), si64Type);
12697+
index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);
12698+
12699+
// Multiply the index with the stride for the current dimension
12700+
Value cstStride = rewriter.create<ConstantIntOp>(
12701+
loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
12702+
index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
12703+
12704+
// Reset the current dimension to 1 for the next iteration
12705+
viewShapeInts[dim] = 1;
12706+
viewShapeListElems[dim] = cstOne;
12707+
12708+
if (dim == 0) {
12709+
finalIndices = index;
12710+
continue;
12711+
}
12712+
12713+
// calculate common shape for broadcast
12714+
SmallVector<int64_t> broadcastShape;
12715+
SmallVector<Value> broadcastShapeValue;
12716+
computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
12717+
broadcastShapeValue);
12718+
Type broadcastType = ValueTensorType::get(
12719+
context, llvm::ArrayRef(broadcastShape), si64Type);
12720+
12721+
finalIndices = rewriter.create<AtenAddTensorOp>(
12722+
loc, broadcastType, finalIndices, index, cstOne);
12723+
}
12724+
12725+
int64_t flattenedResultSize = 1;
12726+
for (int64_t size : sizesInts)
12727+
flattenedResultSize *= size;
12728+
12729+
// Flattening the indices and adding the storage offset
12730+
finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
12731+
loc,
12732+
ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
12733+
si64Type),
12734+
finalIndices, cstZero, cstMinusOne); // -1 means flatten all
12735+
12736+
if (storageOffset != 0) {
12737+
Value cstStorageOffset = rewriter.create<ConstantIntOp>(
12738+
loc, rewriter.getI64IntegerAttr(storageOffset));
12739+
finalIndices = rewriter.create<AtenAddScalarOp>(
12740+
loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
12741+
}
12742+
12743+
// Index the flattened input tensor
12744+
Type listElemType =
12745+
inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
12746+
/*optionalDtype=*/nullptr);
12747+
Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
12748+
loc, Torch::ListType::get(listElemType),
12749+
SmallVector<Value>{finalIndices});
12750+
12751+
auto flattenedResultTy =
12752+
ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
12753+
inputType.getOptionalDtype());
12754+
Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
12755+
input, indicesList);
12756+
12757+
// Reshape the result to the desired output size
12758+
SmallVector<Value> sizesIntsValues;
12759+
for (int64_t size : sizesInts) {
12760+
sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
12761+
loc, rewriter.getI64IntegerAttr(size)));
12762+
}
12763+
Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
12764+
loc, Torch::ListType::get(Torch::IntType::get(context)),
12765+
sizesIntsValues);
12766+
result =
12767+
rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
12768+
12769+
rewriter.replaceOp(op, result);
12770+
return success();
12771+
}
12772+
};
12773+
} // namespace
12774+
1258312775
namespace {
1258412776
class DecomposeComplexOpsPass
1258512777
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12904,6 +13096,7 @@ class DecomposeComplexOpsPass
1290413096
patterns);
1290513097
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
1290613098
addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
13099+
addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);
1290713100

1290813101
GreedyRewriteConfig config;
1290913102
config.setUseTopDownTraversal(true);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
590590
target.addIllegalOp<AtenLogaddexpOp>();
591591
target.addIllegalOp<AtenLogaddexp2Op>();
592592
target.addIllegalOp<AtenKlDivOp>();
593+
target.addIllegalOp<AtenAsStridedOp>();
593594

594595
for (auto &opName : backendLegalOpsSet) {
595596
target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -487,17 +487,6 @@
487487
"ViewSizeFromOtherTensor_basic",
488488
"ViewDtypeStaticModule_basic",
489489
"WeightNormInterfaceModule_basic",
490-
# Error: `aten.as_strided` op is not supported
491-
"ChunkListUnpackDynamic_Module_basic",
492-
"ChunkListUnpackUnevenDynamic_Module_basic",
493-
"ChunkListUnpackUneven_Module_basic",
494-
"ChunkListUnpack_Module_basic",
495-
"SplitTensorGetItem_Module_basic",
496-
"SplitTensorLastSmallerModule_basic",
497-
"SplitTensorListUnpackModule_basic",
498-
"SplitTensorNegativeDimModule_basic",
499-
"SplitWithSizesListUnpackModule_basic",
500-
"SplitWithSizes_Module_basic",
501490
"AdaptiveAvgPool1dGeneralDynamic_basic",
502491
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
503492
"AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -526,8 +515,6 @@
526515
"ReflectionPad3dModuleRight_basic",
527516
"ReflectionPad3dModuleFront_basic",
528517
"ReflectionPad3dModuleBack_basic",
529-
# RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
530-
"NativeGroupNormModule_basic",
531518
}
532519

533520
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -994,6 +981,8 @@
994981
"NativeGroupNormModule_basic",
995982
"AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
996983
"MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
984+
"AtenAsStridedModule_basic",
985+
"AtenAsStridedNoStorageOffsetModule_basic",
997986
}
998987

999988
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3986,6 +3975,19 @@
39863975
"ReplicationPad1dModule_3DInput_basic",
39873976
"ReplicationPad3dModule_basic",
39883977
"ReplicationPad3dModuleSingleIntPad_basic",
3978+
"AtenAsStridedModule_basic",
3979+
"AtenAsStridedNoStorageOffsetModule_basic",
3980+
"ChunkListUnpackDynamic_Module_basic",
3981+
"ChunkListUnpackUnevenDynamic_Module_basic",
3982+
"ChunkListUnpackUneven_Module_basic",
3983+
"ChunkListUnpack_Module_basic",
3984+
"NativeGroupNormModule_basic",
3985+
"SplitTensorGetItem_Module_basic",
3986+
"SplitTensorLastSmallerModule_basic",
3987+
"SplitTensorListUnpackModule_basic",
3988+
"SplitTensorNegativeDimModule_basic",
3989+
"SplitWithSizesListUnpackModule_basic",
3990+
"SplitWithSizes_Module_basic",
39893991
}
39903992

39913993
ONNX_TOSA_CRASHING_SET = {

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6878,3 +6878,48 @@ def forward(self, x):
68786878
@register_test_case(module_factory=lambda: Aten_AssertScalar())
68796879
def Aten_AssertScalar_basic(module, tu: TestUtils):
68806880
module.forward(torch.tensor(4))
6881+
6882+
6883+
# ==============================================================================
6884+
6885+
6886+
class AtenAsStridedModule(torch.nn.Module):
    """E2E test module for aten.as_strided with an explicit storage offset."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 5, 6], torch.float32, True),
        ]
    )
    def forward(self, x):
        # Positional form of as_strided(x, size=(2, 2), stride=(3, 3),
        # storage_offset=1): view x as a (2, 2) tensor, both dims striding
        # by 3 over the flat storage, starting at element 1.
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)
6901+
6902+
6903+
@register_test_case(module_factory=lambda: AtenAsStridedModule())
def AtenAsStridedModule_basic(module, tu: TestUtils):
    # Drive the module with a random (4, 5, 6) input matching its
    # annotate_args shape; correctness is checked against eager PyTorch
    # by the e2e test harness.
    module.forward(torch.randn(4, 5, 6))
6906+
6907+
6908+
class AtenAsStridedNoStorageOffsetModule(torch.nn.Module):
    """E2E test module for aten.as_strided when storage_offset is omitted."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([12, 13], torch.float32, True),
        ]
    )
    def forward(self, x):
        # Positional form of as_strided(x, size=(3, 4), stride=(2, 5));
        # storage_offset stays unset to exercise the default (None) path.
        return torch.ops.aten.as_strided(x, (3, 4), (2, 5))
6921+
6922+
6923+
@register_test_case(module_factory=lambda: AtenAsStridedNoStorageOffsetModule())
def AtenAsStridedNoStorageOffsetModule_basic(module, tu: TestUtils):
    # Drive the module with a random (12, 13) input matching its
    # annotate_args shape; correctness is checked against eager PyTorch
    # by the e2e test harness.
    module.forward(torch.randn(12, 13))

0 commit comments

Comments
 (0)