@@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.unfold
+template <>
+LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
+    AtenUnfoldOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Approach: Use GatherOp to retrieve target elements from target dim and
+  // then reshape the output into slices according to the output shape
+  //
+  // Lowering steps:
+  // 1. Create PyTorch-style indices tensor corresponding to target elements
+  //    and reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1))
+  //    with d_x being the dimension size of the input at dim x.
+  //    The indices vector will be calculated using the following formula:
+  //      for i in range(d_0 * d_1 * ... * d_(target_dim - 1)):
+  //        for window in range(nWindows):
+  //          for elementIndex in range(size):
+  //            for j in range(d_(target_dim + 1) * ... * d_(rank - 1)):
+  //              indices_vec.push_back(elementIndex + window * step)
+  // 2. Convert PyTorch-style indices and target dim to TensorFlow-style
+  //    indices
+  // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  //    target elements
+  // 4. Reshape result from above to correct output shape
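+  //
+  // For reference, aten.unfold extracts sliding windows of length `size`,
+  // spaced `step` apart, along `dim`, and appends the window as a new
+  // trailing dimension. Illustrative example (PyTorch semantics):
+  //   torch.arange(7).unfold(0, 2, 3) == tensor([[0, 1], [3, 4]])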
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t dim;
+  if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Only constant int dims are supported");
+
+  int64_t size;
+  if (!matchPattern(op.getSize(), m_TorchConstantInt(&size)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant int sizes are supported");
+
+  int64_t step;
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant int steps are supported");
+
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op, "Step value must be greater than 0");
+
+  // Handle rank zero
+  if (selfRank == 0) {
+    if (dim != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported dim value for rank zero input");
+
+    if (size != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported size value for rank zero input");
+
+    auto result = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), RankedTensorType::get({1}, selfElemTy), self,
+        rewriter.getDenseI64ArrayAttr({1}));
+
+    rewriter.replaceOp(op, {result.getResult()});
+    return success();
+  }
+
+  dim = toPositiveDim(dim, selfRank);
+  if (!isValidDim(dim, selfRank))
+    return rewriter.notifyMatchFailure(op, "Dim value is invalid");
+
+  // Size of dimension 'dim' in the returned tensor (or number of windows
+  // within the dimension that got sliced)
+  int64_t nWindows = (selfShape[dim] - size) / step + 1;
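+  // This matches PyTorch's window count (shape[dim] - size) // step + 1,
+  // e.g. shape[dim] = 7, size = 2, step = 3 -> nWindows = 2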
+
+  // Find number of times that each base index value gets repeated for target
+  // dim based on dim values before and after target dim, i.e.
+  //   preDimAccumulate = d_0 * d_1 * ... * d_(target_dim - 1)
+  //   postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1)
+  int64_t preDimAccumulate =
+      std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1,
+                      std::multiplies<int64_t>());
+  int64_t postDimAccumulate =
+      std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1,
+                      std::multiplies<int64_t>());
+
+  // Calculate PyTorch-style gather indices vector
+  // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1
+  // -> preDimAccumulate = 2, postDimAccumulate = 3, nWindows = 2
+  // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                          1, 1, 1, 2, 2, 2, 3, 3, 3]
+  // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3,
+  //                      0, 0, 0, 1, 1, 1, 2, 2, 2,
+  //                      1, 1, 1, 2, 2, 2, 3, 3, 3]
+  SmallVector<int32_t> pyTorchIndicesBaseVec;
+  SmallVector<int32_t> pyTorchIndicesVec;
+
+  for (int64_t window = 0; window < nWindows; window++) {
+    for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) {
+      int32_t baseIndex = static_cast<int32_t>(elementIndex + window * step);
+      for (int64_t i = 0; i < postDimAccumulate; i++)
+        pyTorchIndicesBaseVec.push_back(baseIndex);
+    }
+  }
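+  // pyTorchIndicesBaseVec now holds nWindows * size * postDimAccumulate
+  // entries (2 * 3 * 3 = 18 in the example above).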
+
+  for (int64_t i = 0; i < preDimAccumulate; i++)
+    pyTorchIndicesVec.insert(pyTorchIndicesVec.end(),
+                             pyTorchIndicesBaseVec.begin(),
+                             pyTorchIndicesBaseVec.end());
+
+  // Create the PyTorch-style indices tensor
+  // Continuing with the previous example:
+  // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3)
+  // pyTorchIndices = tensor([[[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]],
+  //                          [[0, 0, 0],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [1, 1, 1],
+  //                           [2, 2, 2],
+  //                           [3, 3, 3]]])
+  SmallVector<int64_t> pyTorchIndicesShape(selfShape);
+  pyTorchIndicesShape[dim] = nWindows * size;
+  auto pyTorchIndices =
+      tosa::getConstTensor<int32_t>(rewriter, op, pyTorchIndicesVec,
+                                    pyTorchIndicesShape)
+          .value();
+
+  // Convert PyTorch-style indices to TensorFlow-style indices
+  auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self,
+                                                      pyTorchIndices, dim);
+  if (!tfIndices)
+    return rewriter.notifyMatchFailure(op,
+                                       "Convert PyTorch-style indices and dim "
+                                       "to TensorFlow-style indices failed");
+
+  // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
+  // target elements
+  auto gatherNdOp = tosa::convertGatherNdOp(
+      rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy),
+      self, tfIndices.value());
+  if (!gatherNdOp)
+    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
+
+  // Reshape to an intermediary shape where the gathered elements in dimension
+  // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size'
+  SmallVector<int64_t> intermediaryShape;
+  for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) {
+    if (currentDim == dim) {
+      intermediaryShape.push_back(nWindows);
+      intermediaryShape.push_back(size);
+    } else {
+      intermediaryShape.push_back(pyTorchIndicesShape[currentDim]);
+    }
+  }
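+  // In the running example:
+  //   intermediaryShape = (2, nWindows, size, 3) = (2, 2, 3, 3)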
+
+  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy),
+      gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape));
8428+
8429+ // Permute dims to the correct result order
8430+ SmallVector<int32_t > permutedDims;
8431+ for (int64_t currentDim = 0 ; currentDim < selfRank + 1 ; currentDim++) {
8432+ if (currentDim != dim + 1 )
8433+ permutedDims.push_back (static_cast <int32_t >(currentDim));
8434+ }
8435+ permutedDims.push_back (static_cast <int32_t >(dim + 1 ));
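+  // This moves the 'size' dimension (dim + 1 after the split) to the end,
+  // since aten.unfold appends the window as the last dimension.
+  // In the running example: permutedDims = [0, 1, 3, 2]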
+
+  auto permutedDimsConst = tosa::getConstTensor<int32_t>(
+                               rewriter, op,
+                               /*vec=*/permutedDims,
+                               /*shape=*/{static_cast<int32_t>(selfRank + 1)})
+                               .value();
+
+  auto result = rewriter.create<tosa::TransposeOp>(
+      op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -8617,6 +8809,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenLog1pOp);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \