Skip to content

Commit 9529dcc

Browse files
authored
Merge pull request #480 from Xilinx/bump_to_2b01f8b7
[AutoBump] Merge with fixes of 2b01f8b (Oct 26) (95)
2 parents 157ed54 + 5cbaae6 commit 9529dcc

File tree

3 files changed

+83
-38
lines changed

3 files changed

+83
-38
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4639,6 +4639,25 @@ class ConvertAtenIndexTensorOpNone
46394639
}
46404640
};
46414641

4642+
Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
4643+
ConversionPatternRewriter &rewriter) {
4644+
4645+
auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
4646+
auto maxIndexValue =
4647+
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
4648+
4649+
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4650+
4651+
auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
4652+
rewriter, op->getLoc(), indexType, maxIndexValue, index);
4653+
auto boolType = indexType.clone(rewriter.getIntegerType(1));
4654+
auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
4655+
rewriter, op->getLoc(), boolType, zeroValue, index);
4656+
return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
4657+
indexType, isNegativeIndices,
4658+
wrappedIndicesOp, index);
4659+
}
4660+
46424661
template <>
46434662
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46444663
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4677,6 +4696,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46774696

46784697
auto outType = getTypeConverter()->convertType(op.getType());
46794698

4699+
Operation *indicesTf;
4700+
46804701
// Support for multiple indexes
46814702
if (indexTensors.size() > 1) {
46824703
// t[i, i]
@@ -4710,6 +4731,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
47104731
index);
47114732
}
47124733

4734+
index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
4735+
rewriter);
47134736
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
47144737
SmallVector<int64_t> indiceShapeOneDim;
47154738
for (auto shape : indexShape) {
@@ -4852,57 +4875,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
48524875
auto indicesShapeConcat = indexesShape[0];
48534876
uint64_t lastDim = indexesRank[0];
48544877
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
4855-
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4878+
indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
48564879
rewriter, op->getLoc(),
48574880
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
48584881
indicesTfConcatTensors, lastDim);
48594882

4860-
if (!indicesTf) {
4861-
return rewriter.notifyMatchFailure(
4862-
op, "Convert TorchIndex To TfIndices fail.");
4863-
}
4864-
// do the tf gathernp algorithm with tf style indices as input.
4865-
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
4866-
indicesTf.getResult());
4883+
} else {
48674884

4868-
if (!result) {
4869-
return rewriter.notifyMatchFailure(
4870-
op, "Convert GatherNdOp fail for index tensor.");
4885+
// Single index
4886+
auto index = indexTensors[0];
4887+
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4888+
auto indexShape = indexType.getShape();
4889+
// index i64 to i32 for tosa compatible
4890+
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
4891+
index = rewriter.create<tosa::CastOp>(
4892+
op->getLoc(),
4893+
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
4894+
index);
48714895
}
4872-
rewriter.replaceOp(op, {result.value()});
48734896

4874-
return success();
4875-
}
4897+
index =
4898+
wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
48764899

4877-
// Support for multiple index
4878-
auto index = indexTensors[0];
4879-
auto indexType = dyn_cast<RankedTensorType>(index.getType());
4880-
auto indexShape = indexType.getShape();
4881-
// index i64 to i32 for tosa compatible
4882-
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
4883-
index = rewriter.create<tosa::CastOp>(
4884-
op->getLoc(),
4885-
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
4886-
}
4887-
4888-
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
4889-
SmallVector<int64_t> indicesShape;
4890-
for (auto shape : indexShape) {
4891-
indicesShape.push_back(shape);
4900+
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
4901+
SmallVector<int64_t> indicesShape;
4902+
for (auto shape : indexShape) {
4903+
indicesShape.push_back(shape);
4904+
}
4905+
indicesShape.push_back(1);
4906+
indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4907+
rewriter, op->getLoc(),
4908+
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
4909+
rewriter.getDenseI64ArrayAttr(indicesShape));
48924910
}
4893-
indicesShape.push_back(1);
4894-
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
4895-
rewriter, op->getLoc(),
4896-
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
4897-
rewriter.getDenseI64ArrayAttr(indicesShape));
48984911

48994912
if (!indicesTf) {
49004913
return rewriter.notifyMatchFailure(op,
49014914
"Convert TorchIndex To TfIndices fail.");
49024915
}
49034916
// do the tf gathernp algorithm with tf style indices as input.
49044917
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
4905-
indicesTf.getResult());
4918+
indicesTf->getResult(0));
49064919

49074920
if (!result) {
49084921
return rewriter.notifyMatchFailure(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,15 +1747,13 @@
17471747
"ArangeStartOutModule_basic",
17481748
"ScatterSrcStaticModule_basic",
17491749
# Runtime op verification: Out of bounds access
1750-
"IndexTensorNegativeIndexModule_basic",
17511750
"ReduceAllDimEmpty_basic",
17521751
}
17531752

17541753
FX_IMPORTER_TOSA_CRASHING_SET = {
17551754
"ScatterSrcModule_basic",
17561755
"ScatterSrcStaticModule_basic",
17571756
"HBC_basic",
1758-
"IndexTensorNegativeIndexModule_basic",
17591757
"InterpolateDynamicModule_scales_recompute_bilinear",
17601758
"InterpolateDynamicModule_sizes_bilinear",
17611759
"InterpolateDynamicModule_sizes_nearest",
@@ -2217,6 +2215,7 @@
22172215
"HardswishRandomModule_basic",
22182216
"HardtanhBackward_basic",
22192217
"IndexTensorMultiIndexStaticModule_basic",
2218+
"IndexTensorNegativeIndexModule_basic",
22202219
"IndexTensorStaticModule_basic",
22212220
"IscloseStaticModuleTrue_basic",
22222221
"IscloseStaticModule_basic",
@@ -3670,7 +3669,7 @@
36703669
"IndexPutImpl3DFloatAccumulateModule_basic",
36713670
"IndexPutImpl3DFloatNonAccumulateModule_basic",
36723671
"IndexPutImplIndexWithNoneModule_basic",
3673-
"IndexTensorNegativeIndexModule_basic",
3672+
"IndexSelectRank0IdxModule_basic",
36743673
"InterpolateDynamicModule_sizes_bilinear",
36753674
"InterpolateDynamicModule_sizes_nearest",
36763675
"InterpolateStaticModule_scales_bilinear_align_corners",
@@ -4000,6 +3999,7 @@
40003999
"GridSamplerBasic2_basic",
40014000
"GridSamplerBasic3_basic",
40024001
"GridSamplerBasic4_basic",
4002+
"IndexSelectRank0IdxModule_basic",
40034003
"IouOfModule_basic",
40044004
"MaxPool1dEmptyStrideStaticModule_basic",
40054005
"MaxPool1dStaticCeilModeTrueModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,3 +2373,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
23732373
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
23742374
return %0 : !torch.vtensor<[2,3,4,4],f32>
23752375
}
2376+
2377+
// -----
2378+
2379+
// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
2380+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
2381+
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
2382+
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
2383+
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
2384+
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
2385+
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
2386+
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
2387+
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
2388+
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
2389+
// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
2390+
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
2391+
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
2392+
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
2393+
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
2394+
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
2395+
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
2396+
// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
2397+
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
2398+
// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
2399+
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
2400+
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
2401+
// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>
2402+
2403+
func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
2404+
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
2405+
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
2406+
return %1 : !torch.vtensor<[4,2],si64>
2407+
}

0 commit comments

Comments (0)