@@ -4639,6 +4639,25 @@ class ConvertAtenIndexTensorOpNone
46394639 }
46404640};
46414641
4642+ Value wrapNegativeIndices (Value index, int maxIndex, Operation *op,
4643+ ConversionPatternRewriter &rewriter) {
4644+
4645+ auto zeroValue = tosa::getConstTensor<int32_t >(rewriter, op, 0 , {}).value ();
4646+ auto maxIndexValue =
4647+ tosa::getConstTensor<int32_t >(rewriter, op, maxIndex, {}).value ();
4648+
4649+ auto indexType = dyn_cast<RankedTensorType>(index.getType ());
4650+
4651+ auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
4652+ rewriter, op->getLoc (), indexType, maxIndexValue, index);
4653+ auto boolType = indexType.clone (rewriter.getIntegerType (1 ));
4654+ auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
4655+ rewriter, op->getLoc (), boolType, zeroValue, index);
4656+ return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc (),
4657+ indexType, isNegativeIndices,
4658+ wrappedIndicesOp, index);
4659+ }
4660+
46424661template <>
46434662LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46444663 AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4677,6 +4696,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46774696
46784697 auto outType = getTypeConverter ()->convertType (op.getType ());
46794698
4699+ Operation *indicesTf;
4700+
46804701 // Support for multiple indexes
46814702 if (indexTensors.size () > 1 ) {
46824703 // t[i, i]
@@ -4710,6 +4731,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
47104731 index);
47114732 }
47124733
4734+ index = wrapNegativeIndices (index, inputTensorType.getShape ()[i], op,
4735+ rewriter);
47134736 // Expand last dim of index to tf indices [2,3] -> [2,3,1]
47144737 SmallVector<int64_t > indiceShapeOneDim;
47154738 for (auto shape : indexShape) {
@@ -4852,57 +4875,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
48524875 auto indicesShapeConcat = indexesShape[0 ];
48534876 uint64_t lastDim = indexesRank[0 ];
48544877 indicesShapeConcat.push_back (indicesTfConcatTensors.size ());
4855- auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4878+ indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
48564879 rewriter, op->getLoc (),
48574880 GetTypeFromTensorShape (indicesShapeConcat, rewriter.getIntegerType (32 )),
48584881 indicesTfConcatTensors, lastDim);
48594882
4860- if (!indicesTf) {
4861- return rewriter.notifyMatchFailure (
4862- op, " Convert TorchIndex To TfIndices fail." );
4863- }
4864- // do the tf gathernp algorithm with tf style indices as input.
4865- auto result = tosa::convertGatherNdOp (rewriter, op, outType, input,
4866- indicesTf.getResult ());
4883+ } else {
48674884
4868- if (!result) {
4869- return rewriter.notifyMatchFailure (
4870- op, " Convert GatherNdOp fail for index tensor." );
4885+ // Single index
4886+ auto index = indexTensors[0 ];
4887+ auto indexType = dyn_cast<RankedTensorType>(index.getType ());
4888+ auto indexShape = indexType.getShape ();
4889+ // index i64 to i32 for tosa compatible
4890+ if (indexType.getElementType () != rewriter.getIntegerType (32 )) {
4891+ index = rewriter.create <tosa::CastOp>(
4892+ op->getLoc (),
4893+ RankedTensorType::get (indexShape, rewriter.getIntegerType (32 )),
4894+ index);
48714895 }
4872- rewriter.replaceOp (op, {result.value ()});
48734896
4874- return success ();
4875- }
4897+ index =
4898+ wrapNegativeIndices (index, inputTensorType. getShape ()[ 0 ], op, rewriter);
48764899
4877- // Support for multiple index
4878- auto index = indexTensors[0 ];
4879- auto indexType = dyn_cast<RankedTensorType>(index.getType ());
4880- auto indexShape = indexType.getShape ();
4881- // index i64 to i32 for tosa compatible
4882- if (indexType.getElementType () != rewriter.getIntegerType (32 )) {
4883- index = rewriter.create <tosa::CastOp>(
4884- op->getLoc (),
4885- RankedTensorType::get (indexShape, rewriter.getIntegerType (32 )), index);
4886- }
4887-
4888- // Expand last dim of index to tf indices [2,3] -> [2,3,1]
4889- SmallVector<int64_t > indicesShape;
4890- for (auto shape : indexShape) {
4891- indicesShape.push_back (shape);
4900+ // Expand last dim of index to tf indices [2,3] -> [2,3,1]
4901+ SmallVector<int64_t > indicesShape;
4902+ for (auto shape : indexShape) {
4903+ indicesShape.push_back (shape);
4904+ }
4905+ indicesShape.push_back (1 );
4906+ indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4907+ rewriter, op->getLoc (),
4908+ RankedTensorType::get (indicesShape, rewriter.getIntegerType (32 )), index,
4909+ rewriter.getDenseI64ArrayAttr (indicesShape));
48924910 }
4893- indicesShape.push_back (1 );
4894- auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4895- rewriter, op->getLoc (),
4896- RankedTensorType::get (indicesShape, rewriter.getIntegerType (32 )), index,
4897- rewriter.getDenseI64ArrayAttr (indicesShape));
48984911
48994912 if (!indicesTf) {
49004913 return rewriter.notifyMatchFailure (op,
49014914 " Convert TorchIndex To TfIndices fail." );
49024915 }
49034916 // do the tf gathernp algorithm with tf style indices as input.
49044917 auto result = tosa::convertGatherNdOp (rewriter, op, outType, input,
4905- indicesTf. getResult ());
4918+ indicesTf-> getResult (0 ));
49064919
49074920 if (!result) {
49084921 return rewriter.notifyMatchFailure (
0 commit comments