
Commit 51da49c

[Torch] Add decomposition for 1d torch.nonzero (llvm#3876)
2d static nonzero also works, but 2d dynamic needs to be fixed next.
Parent commit: 061bbc5
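As a quick reminder of the semantics being decomposed (standard PyTorch behavior, not code from this commit): torch.nonzero returns the indices of the non-zero elements, one row per element and one column per input dimension, which is what the pattern below has to reproduce.

    import torch

    # 1-D input: the result has shape (num_nonzero, 1).
    t1 = torch.tensor([0, 0, 1, 1, 0, 0])
    print(torch.nonzero(t1))
    # -> tensor([[2], [3]])

    # 2-D input: one (row, col) index pair per non-zero element.
    t2 = torch.tensor([[0, 1, 0], [2, 0, 3]])
    print(torch.nonzero(t2))
    # -> tensor([[0, 1], [1, 0], [1, 2]])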

File tree

3 files changed: +260 / -1 lines changed

  lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
  projects/pt1/e2e_testing/xfail_sets.py
  projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 235 additions & 0 deletions
@@ -5705,6 +5705,240 @@ class DecomposeAtenConvolutionBackwardOp
 };
 } // namespace

+/**
+ * # one dim input
+ * t = torch.tensor([0, 0, 1, 1, 0, 0])
+ * # t_flat:[0, 0, 1, 1, 0, 0]
+ * t_flat = t.flatten(0, 0)
+ * nonzero_mask = t_flat != 0
+ * # nonzero_mask:[0, 0, 1, 1, 0, 0]
+ * nonzero_mask = nonzero_mask.long()
+ * # destination_indices:[-1, -1, 0, 1, 1, 1]
+ * destination_indices = torch.cumsum(nonzero_mask, 0) - 1
+ * # destination_indices_clamp:[0, 0, 0, 1, 1, 1]
+ * destination_indices_clamp = torch.clamp(destination_indices, min=0)
+ * # iota:[0, 0, 2, 3, 0, 0]
+ * iota = torch.arange(t_flat.size(0)) * nonzero_mask
+ * # scatter_self:[0, 0, 0, 0, 0, 0]
+ * scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
+ * # compacted:[2, 3, 0, 0, 0, 0]
+ * compacted = torch.scatter_add(
+ *     scatter_self, dim=0, index=destination_indices_clamp, src=iota
+ * )
+ * # result_flat:[2, 3]
+ * result_flat = compacted[: torch.sum(nonzero_mask)]
+ *
+ * # multi dim support
+ * original_shape = t.shape
+ * # input_shape_tensor:[6]
+ * input_shape_tensor = torch.tensor(original_shape)
+ * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
+ *
+ * one = torch.tensor([1])
+ * if t.dim() > 1:
+ *     slicedStrides = strides[1:]
+ *     strides = torch.cat([slicedStrides, one])
+ * else:
+ *     strides = one
+ * # a: tensor([[2], [3]]) torch.Size([2, 1])
+ * a = result_flat.unsqueeze(1)
+ * # b: tensor([[1]]) torch.Size([1, 1])
+ * b = strides.unsqueeze(0)
+ * # c: tensor([[2], [3]]) torch.Size([2, 1])
+ * c = a // b
+ * # result: tensor([[2], [3]]) torch.Size([2, 1])
+ * result = c % input_shape_tensor
+ */
+class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNonzeroOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resultType = cast<BaseTensorType>(op.getType());
+    auto intType = resultType.getDtype();
+    Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType);
+    auto constantZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    auto constantOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    std::function<Value(Value)> makeOneElementList = [&](Value element) {
+      auto listType = Torch::ListType::get(element.getType());
+      return rewriter.create<PrimListConstructOp>(loc, listType,
+                                                  ArrayRef<Value>{element});
+    };
+
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    int64_t inputRank = inputType.getSizes().size();
+
+    // t_flat = t.flatten() # torch.flatten(t, 0, 0)
+    int64_t flattenedSize = 1;
+    if (inputType.hasSizes()) {
+      for (auto size : inputType.getSizes()) {
+        flattenedSize *= size;
+      }
+    } else {
+      flattenedSize = kUnknownSize;
+    }
+
+    auto flattendInputShape = SmallVector<int64_t>{flattenedSize};
+    auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
+        flattendInputShape, inputType.getOptionalDtype());
+
+    // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 :
+    auto inputDimsEnd = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputRank - 1));
+    Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, flattenedInputType, input, constantZero /*inputDimsStart*/,
+        inputDimsEnd /*inputDimsEnd*/);
+
+    // nonzero_mask = (t_flat != 0)
+    auto boolMaskType = inputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
+    Value boolMask = rewriter.create<AtenNeScalarOp>(
+        loc, boolMaskType, flattenedInput, constantZero);
+
+    // nonzero_mask = nonzero_mask.int()
+    Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
+    Value noneCst = rewriter.create<ConstantNoneOp>(loc);
+    auto intMaskType = flattenedInputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), intType);
+    Value intMask = rewriter.create<AtenToDtypeOp>(
+        loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst);
+
+    // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
+    Value cumulativeSum = rewriter.create<AtenCumsumOp>(
+        loc, intMaskType, intMask, constantZero, noneCst);
+    Value subtracted = rewriter.create<AtenSubScalarOp>(
+        loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);
+
+    // destination_indices = torch.clamp(destination_indices, min=0)
+    Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
+                                                    subtracted, constantZero);
+
+    // iota = torch.arange(len(t_flat)) * nonzero_mask
+    Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
+                                               /*dim=*/constantZero);
+    Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
+        loc, intMaskType, /*start*/ constantZero, /*end*/ end,
+        /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
+    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
+                                                        rangeTensor, intMask);
+
+    // scatter_self = torch.zeros_like(t, dtype=torch.int64)
+    // AtenFullLike doesn't support index type so we have to use int.
+    Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
+        loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst,
+        noneCst, noneCst);
+
+    // compacted = torch.scatter_add(
+    //     scatter_self, dim=0, index=destination_indices_clamp, src=iota)
+    Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
+        loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
+        /*index=*/indices, /*src=*/multiplied);
+
+    // result_flat = compacted[:torch.sum(nonzero_mask)]
+    auto scalarType = ValueTensorType::get(rewriter.getContext(),
+                                           ArrayRef<int64_t>{}, intType);
+    Value sumMask =
+        rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
+    Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);
+
+    auto slicedResultType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
+    Value slicedResult =
+        rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
+                                           /*self=*/scatteredTensor,
+                                           /*dim=*/constantZero,
+                                           /*start=*/noneCst,
+                                           /*end=*/numNonzero,
+                                           /*step=*/constantOne);
+
+    // TODO: fix multidim dynamic support. The following code only works for
+    // static multidim. Convert flattened indices back to multi-dimensional
+    // indices:
+    //   original_shape = t.shape
+    //   input_shape_tensor = torch.tensor(original_shape)
+    auto shapeType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{inputRank}, intType);
+    SmallVector<Value> shapeValues;
+    for (int i = 0; i < inputRank; i++) {
+      auto constantI =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      Value shape = rewriter.create<AtenSizeIntOp>(loc, input,
+                                                   /*dim=*/constantI);
+      shapeValues.push_back(shape);
+    }
+    Value shapeTensorList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues);
+    Value inputShapeTensor = rewriter.create<Torch::AtenTensorOp>(
+        loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst);
+
+    // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
+    Value flippedShape = rewriter.create<AtenFlipOp>(
+        loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
+    Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
+        loc, shapeType, flippedShape, constantZero, noneCst);
+    Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
+        loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
+
+    // strides = torch.cat([strides[1:], torch.tensor([1])])
+    auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
+                                              SmallVector<int64_t>{1}, intType);
+    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+        loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
+        noneCst);
+
+    Value strides;
+    if (inputRank > 1) {
+      // strides[1:]
+      auto slicedStrideType = Torch::ValueTensorType::get(
+          rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
+          intType);
+      Value strideSliceEnd = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank));
+      Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
+          loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
+          /*dim*/ constantZero,
+          /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
+      // torch.cat
+      auto tensorListElementType = Torch::ValueTensorType::get(
+          rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
+      Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+          loc, Torch::ListType::get(tensorListElementType),
+          SmallVector<Value>{slicedStrides, oneTensor});
+      strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList,
+                                                  constantZero);
+    } else {
+      // strides[1:] is empty, so the only stride is 1.
+      strides = oneTensor;
+    }
+
+    // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
+    //     input_shape_tensor
+    auto unsqueezedResultType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, intType);
+    Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedResultType, slicedResult, constantOne);
+
+    auto unsqueezedStridesType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, intType);
+    Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedStridesType, strides, constantZero);
+
+    auto dividedBroadcastType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
+        intType);
+    Value divided = rewriter.create<AtenFloorDivideOp>(
+        loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
+
+    Value modded = rewriter.create<AtenRemainderTensorOp>(
+        loc, resultType, divided, inputShapeTensor);
+
+    rewriter.replaceOp(op, modded);
+    return success();
+  }
+};
+
 // Decompose aten.addmm into aten.mm and aten.add.Tensor op.
 namespace {
 class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -11263,6 +11497,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
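To make the doc-comment above concrete, here is a minimal runnable PyTorch sketch of the same decomposition (illustration only, not code from this commit; the unravel step uses strides[1:], which is what the C++ slice from 1 to inputRank computes):

    import torch

    def nonzero_by_decomposition(t):
        # Compact the flat indices of non-zero elements via cumsum + scatter_add.
        t_flat = t.flatten()
        mask = (t_flat != 0).long()
        dest = torch.clamp(torch.cumsum(mask, 0) - 1, min=0)
        iota = torch.arange(t_flat.size(0)) * mask
        compacted = torch.zeros_like(t_flat, dtype=torch.int64).scatter_add(0, dest, iota)
        result_flat = compacted[: torch.sum(mask)]

        # Unravel the flat indices into one column per input dimension.
        shape = torch.tensor(t.shape)
        strides = torch.cumprod(shape.flip(0), 0).flip(0)
        if t.dim() > 1:
            strides = torch.cat([strides[1:], torch.tensor([1])])
        else:
            strides = torch.tensor([1])
        return (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % shape

    print(nonzero_by_decomposition(torch.tensor([0, 0, 1, 1, 0, 0])))
    # -> tensor([[2], [3]])
    t = torch.tensor([[0, 1, 0], [2, 0, 3]])
    print(nonzero_by_decomposition(t))  # -> tensor([[0, 1], [1, 0], [1, 2]])
    print(torch.nonzero(t))             # same values (the "static multidim" case)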

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 1 deletion
@@ -399,6 +399,7 @@
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenItemFpOpModule_basic",
+    "AtenNonzero1DDynamicModule_basic",  # no lowering for torch.aten.sym_constrain_range_for_size
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
     "QuantizedReluInt32_basic",
@@ -628,6 +629,7 @@
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",
+    "AtenNonzero1DDynamicModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
     "AtenTopKModule_basic",
@@ -3018,7 +3020,6 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
-    "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 23 additions & 0 deletions
@@ -6430,3 +6430,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):
     module.forward(
         tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
     )
+
+
+# ==============================================================================
+
+
+class AtenNonzero1DDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.bool, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.nonzero(x)
+
+
+@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
+def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
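For reference, the result this new e2e test should produce under standard PyTorch semantics (snippet for illustration, not part of the commit):

    import torch

    x = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
    # A 1-D input yields one single-column index row per True element.
    assert torch.equal(torch.ops.aten.nonzero(x), torch.tensor([[2], [3]]))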
