@@ -5705,6 +5705,240 @@ class DecomposeAtenConvolutionBackwardOp
57055705};
57065706} // namespace
57075707
5708+ /* *
5709+ * # one dim input
5710+ * t = torch.tensor([0, 0, 1, 1, 0, 0]
5711+ * # t_flat:[0, 0, 1, 1, 0, 0]
5712+ * t_flat = t.flatten(0, 0)
5713+ * nonzero_mask = t_flat != 0
5714+ * # nonzero_mask:[0, 0, 1, 1, 0, 0]
5715+ * nonzero_mask = nonzero_mask.long()
5716+ * # destination_indices:[-1, -1, 0, 1, 1, 1]
5717+ * destination_indices = torch.cumsum(nonzero_mask, 0) - 1
5718+ * # destination_indices_clamp:[0, 0, 0, 1, 1, 1]
5719+ * destination_indices_clamp = torch.clamp(destination_indices, min=0)
5720+ * # iota:[0, 0, 2, 3, 0, 0]
5721+ * iota = torch.arange(t_flat.size(0)) * nonzero_mask
5722+ * # scatter_self:[0, 0, 0, 0, 0, 0]
5723+ * scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
5724+ * # compacted:[2, 3, 0, 0, 0, 0]
5725+ * compacted = torch.scatter_add(
5726+ * scatter_self, dim=0, index=destination_indices_clamp, src=iota
5727+ * )
5728+ * # result_flat:[2, 3]
5729+ * result_flat = compacted[: torch.sum(nonzero_mask)]
5730+ *
5731+ * # multi dim support
5732+ * original_shape = t.shape
5733+ * # input_shape_tensor:[6]
5734+ * input_shape_tensor = torch.tensor(original_shape)
5735+ * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
5736+ *
5737+ * one = torch.tensor([1])
5738+ * if(t.dim() > 1):
5739+ * slicedStrides = strides[1:-1]
5740+ * strides = torch.cat([slicedStrides, one])
5741+ * else:
5742+ * strides = one
5743+ * # a: tensor([[2], [3]]) torch.Size([2, 1])
5744+ * a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1])
5745+ * # b: tensor([[1]]) torch.Size([1, 1])
5746+ * b = strides.unsqueeze(0)
5747+ * # c: tensor([[2], [3]]) torch.Size([2, 1])
5748+ * c = a // b
5749+ * # result: tensor([[2], [3]]) torch.Size([2, 1])
5750+ * result = c % input_shape_tensor
5751+ */
5752+ class DecomposeAtenNonzeroOp : public OpRewritePattern <AtenNonzeroOp> {
5753+ using OpRewritePattern::OpRewritePattern;
5754+ LogicalResult matchAndRewrite (AtenNonzeroOp op,
5755+ PatternRewriter &rewriter) const override {
5756+ Location loc = op.getLoc ();
5757+ auto resultType = cast<BaseTensorType>(op.getType ());
5758+ auto intType = resultType.getDtype ();
5759+ Value intTypeValue = getDtypeIntValueForType (rewriter, loc, intType);
5760+ auto constantZero =
5761+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (0 ));
5762+ auto constantOne =
5763+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (1 ));
5764+ std::function<Value (Value)> makeOneElementList = [&](Value element) {
5765+ auto listType = Torch::ListType::get (element.getType ());
5766+ return rewriter.create <PrimListConstructOp>(loc, listType,
5767+ ArrayRef<Value>{element});
5768+ };
5769+
5770+ Value input = op.getSelf ();
5771+ auto inputType = dyn_cast<BaseTensorType>(input.getType ());
5772+ int64_t inputRank = inputType.getSizes ().size ();
5773+
5774+ // t_flat = t.flatten() # torch.flatten(t, 0, 0)
5775+ int64_t flattenedSize = 1 ;
5776+ if (inputType.hasSizes ()) {
5777+ for (auto size : inputType.getSizes ()) {
5778+ flattenedSize *= size;
5779+ }
5780+ } else {
5781+ flattenedSize = kUnknownSize ;
5782+ }
5783+
5784+ auto flattendInputShape = SmallVector<int64_t >{flattenedSize};
5785+ auto flattenedInputType = rewriter.getType <Torch::ValueTensorType>(
5786+ flattendInputShape, inputType.getOptionalDtype ());
5787+
5788+ // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 :
5789+ auto inputDimsEnd = rewriter.create <ConstantIntOp>(
5790+ loc, rewriter.getI64IntegerAttr (inputRank - 1 ));
5791+ Value flattenedInput = rewriter.create <AtenFlattenUsingIntsOp>(
5792+ loc, flattenedInputType, input, constantZero /* inputDimsStart*/ ,
5793+ inputDimsEnd /* inputDimsEnd*/ );
5794+
5795+ // nonzero_mask = (t_flat != 0)
5796+ auto boolMaskType = inputType.getWithSizesAndDtype (
5797+ flattenedInputType.getOptionalSizes (), rewriter.getI1Type ());
5798+ Value boolMask = rewriter.create <AtenNeScalarOp>(
5799+ loc, boolMaskType, flattenedInput, constantZero);
5800+
5801+ // nonzero_mask = nonzero_mask.int()
5802+ Value falseCst = rewriter.create <ConstantBoolOp>(loc, false );
5803+ Value noneCst = rewriter.create <ConstantNoneOp>(loc);
5804+ auto intMaskType = flattenedInputType.getWithSizesAndDtype (
5805+ flattenedInputType.getOptionalSizes (), intType);
5806+ Value intMask = rewriter.create <AtenToDtypeOp>(
5807+ loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst);
5808+
5809+ // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
5810+ Value cumulativeSum = rewriter.create <AtenCumsumOp>(
5811+ loc, intMaskType, intMask, constantZero, noneCst);
5812+ Value subtracted = rewriter.create <AtenSubScalarOp>(
5813+ loc, intMaskType, cumulativeSum, constantOne, /* alpha=*/ constantOne);
5814+
5815+ // destination_indices = torch.clamp(destination_indices, min=0)
5816+ Value indices = rewriter.create <AtenClampMinOp>(loc, intMaskType,
5817+ subtracted, constantZero);
5818+
5819+ // iota = torch.arange(len(t_flat)) * nonzero_mask
5820+ Value end = rewriter.create <AtenSizeIntOp>(loc, flattenedInput,
5821+ /* dim=*/ constantZero);
5822+ Value rangeTensor = rewriter.create <AtenArangeStartStepOp>(
5823+ loc, intMaskType, /* start*/ constantZero, /* end*/ end,
5824+ /* step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
5825+ Value multiplied = rewriter.create <AtenMulTensorOp>(loc, intMaskType,
5826+ rangeTensor, intMask);
5827+
5828+ // scatter_self = torch.zeros_like(t, dtype=torch.int64)
5829+ // AtenFullLike doesn't support index type so we have to use int.
5830+ Value zerosTensor = rewriter.create <AtenZerosLikeOp>(
5831+ loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst,
5832+ noneCst, noneCst);
5833+
5834+ // compacted = torch.scatter_add(
5835+ // scatter_self, dim=0, index=destination_indices_clamp, src=iota)
5836+ Value scatteredTensor = rewriter.create <AtenScatterAddOp>(
5837+ loc, intMaskType, /* self*/ zerosTensor, /* dim=*/ constantZero,
5838+ /* index=*/ indices, /* src=*/ multiplied);
5839+
5840+ // result_flat = compacted[:torch.sum(nonzero_mask)]
5841+ auto scalarType = ValueTensorType::get (rewriter.getContext (),
5842+ ArrayRef<int64_t >{}, intType);
5843+ Value sumMask =
5844+ rewriter.create <AtenSumOp>(loc, scalarType, intMask, noneCst);
5845+ Value numNonzero = rewriter.create <AtenIntTensorOp>(loc, sumMask);
5846+
5847+ auto slicedResultType = Torch::ValueTensorType::get (
5848+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize }, intType);
5849+ Value slicedResult =
5850+ rewriter.create <AtenSliceTensorOp>(loc, slicedResultType,
5851+ /* self=*/ scatteredTensor,
5852+ /* dim=*/ constantZero,
5853+ /* start=*/ noneCst,
5854+ /* end=*/ numNonzero,
5855+ /* step=*/ constantOne);
5856+
5857+ // TODO fix multidim dynamic support. The following code only work for
5858+ // static multidim. Convert flattened indices back to multi-dimensional
5859+ // indices original_shape = t.shape input_shape_tensor =
5860+ // torch.tensor(original_shape)
5861+ auto shapeType = Torch::ValueTensorType::get (
5862+ rewriter.getContext (), SmallVector<int64_t >{inputRank}, intType);
5863+ SmallVector<Value> shapeValues;
5864+ for (int i = 0 ; i < inputRank; i++) {
5865+ auto constantI =
5866+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (i));
5867+ Value shape = rewriter.create <AtenSizeIntOp>(loc, input,
5868+ /* dim=*/ constantI);
5869+ shapeValues.push_back (shape);
5870+ }
5871+ Value shapeTensorList = rewriter.create <Torch::PrimListConstructOp>(
5872+ loc, Torch::ListType::get (shapeValues[0 ].getType ()), shapeValues);
5873+ Value inputShapeTensor = rewriter.create <Torch::AtenTensorOp>(
5874+ loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst);
5875+
5876+ // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
5877+ Value flippedShape = rewriter.create <AtenFlipOp>(
5878+ loc, shapeType, inputShapeTensor, makeOneElementList (constantZero));
5879+ Value cumulativeProduct = rewriter.create <AtenCumprodOp>(
5880+ loc, shapeType, flippedShape, constantZero, noneCst);
5881+ Value flippedCumulativeProduct = rewriter.create <AtenFlipOp>(
5882+ loc, shapeType, cumulativeProduct, makeOneElementList (constantZero));
5883+
5884+ // strides = torch.cat([strides[1:-1], torch.tensor([1])])
5885+ auto oneTensorType = ValueTensorType::get (rewriter.getContext (),
5886+ SmallVector<int64_t >{1 }, intType);
5887+ Value oneTensor = rewriter.create <AtenScalarTensorOp>(
5888+ loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
5889+ noneCst);
5890+
5891+ Value strides;
5892+ if (inputRank > 1 ) {
5893+ // strides[1:-1]
5894+ auto slicedStrideType = Torch::ValueTensorType::get (
5895+ rewriter.getContext (), SmallVector<int64_t >{inputRank - 1 }, // sizes
5896+ intType);
5897+ Value strideSliceEnd = rewriter.create <ConstantIntOp>(
5898+ loc, rewriter.getI64IntegerAttr (inputRank));
5899+ Value slicedStrides = rewriter.create <AtenSliceTensorOp>(
5900+ loc, slicedStrideType, /* self*/ flippedCumulativeProduct,
5901+ /* dim*/ constantZero,
5902+ /* start=*/ constantOne, /* end=*/ strideSliceEnd, /* step=*/ constantOne);
5903+ // torch.cat
5904+ auto tensorListElementType = Torch::ValueTensorType::get (
5905+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize }, intType);
5906+ Value tensorList = rewriter.create <Torch::PrimListConstructOp>(
5907+ loc, Torch::ListType::get (tensorListElementType),
5908+ SmallVector<Value>{slicedStrides, oneTensor});
5909+ strides = rewriter.create <Torch::AtenCatOp>(loc, shapeType, tensorList,
5910+ constantZero);
5911+ } else {
5912+ // strides[1:-1] is empty
5913+ strides = oneTensor;
5914+ }
5915+
5916+ // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
5917+ // input_shape_tensor
5918+ auto unsqueezedResultType = ValueTensorType::get (
5919+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize , 1 }, intType);
5920+ Value unsqueezedResult = rewriter.create <AtenUnsqueezeOp>(
5921+ loc, unsqueezedResultType, slicedResult, constantOne);
5922+
5923+ auto unsqueezedStridesType = ValueTensorType::get (
5924+ rewriter.getContext (), SmallVector<int64_t >{1 , inputRank}, intType);
5925+ Value unsqueezedStrides = rewriter.create <AtenUnsqueezeOp>(
5926+ loc, unsqueezedStridesType, strides, constantZero);
5927+
5928+ auto dividedBroadcastType = ValueTensorType::get (
5929+ rewriter.getContext (), SmallVector<int64_t >{kUnknownSize , inputRank},
5930+ intType);
5931+ Value divided = rewriter.create <AtenFloorDivideOp>(
5932+ loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
5933+
5934+ Value modded = rewriter.create <AtenRemainderTensorOp>(
5935+ loc, resultType, divided, inputShapeTensor);
5936+
5937+ rewriter.replaceOp (op, modded);
5938+ return success ();
5939+ }
5940+ };
5941+
57085942// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
57095943namespace {
57105944class DecomposeAtenAddmmOp : public OpRewritePattern <AtenAddmmOp> {
@@ -11263,6 +11497,7 @@ class DecomposeComplexOpsPass
1126311497 addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
1126411498 patterns);
1126511499 addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
11500+ addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
1126611501 addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
1126711502 addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
1126811503 addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
0 commit comments