Skip to content

Commit 82ca7f3

Browse files
nmalimbanNathan Malimban
andauthored
[tosa] error out for dynamic shape input to Tensor_hacked_twin (#4176)
Hi @sjarus, The motivation for this change is to help the internal MathWorks branch with the tosa-linalg pass. Without this change, `torch.aten.index.Tensor_hacked_twin` will fail when converting to tosa-linalg. With this change, the conversion succeeds. This change adds a check during torch-to-tosa conversion for the `torch.aten.index.Tensor_hacked_twin` operation with dynamically-shaped inputs. With this change, torch-to-tosa conversion will throw **failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal** In torch-to-tosa conversion, `torch.aten.index.Tensor_hacked_twin` lowers to a series of operations including `tosa.gather`, which does not handle dynamic shape. (`torch.gather`, which also lowers to `tosa.gather`, errors out for dynamic shape.) Without this change, a different error is thrown: ``` /mathworks/devel/sandbox/nmalimba/working/gecks/unfold/hacked_twin.mlir:12:11: error: 'tosa.reshape' op result #0 must be tosa-conformant ranked tensor of number values, but got 'tensor<1x0x1xf32>' %28 = torch.aten.index.Tensor_hacked_twin %arg0, %27 : !torch.vtensor<[?,5,3,4],f32>, !torch.list<vtensor> -> !torch.vtensor<[?,5,3,4],f32> ^ /mathworks/devel/sandbox/nmalimba/working/gecks/unfold/hacked_twin.mlir:12:11: note: see current operation: %76 = "tosa.reshape"(%0, %75) : (tensor<?x5x3x4xf32>, !tosa.shape<3>) -> tensor<1x0x1xf32> ``` This is due to `convertGatherNdOp` not properly handling dynamic shapes, which is called when lowering `Tensor_hacked_twin` to tosa. It assumes static types and introduces a 0 dimension when dynamically-shaped tensors are passed as input. Co-authored-by: Nathan Malimban <[email protected]>
1 parent 82ecfc3 commit 82ca7f3

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4485,6 +4485,12 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
44854485
return rewriter.notifyMatchFailure(
44864486
op, "Only tensor types input are currently supported");
44874487

4488+
// Dynamic shape check
4489+
if (!inputTensorType.hasStaticShape())
4490+
return rewriter.notifyMatchFailure(
4491+
op, "AtenIndexTensorHackedTwinOp: support for dynamic input "
4492+
"shape not implemented");
4493+
44884494
// Deal with torch.prim.ListConstruct of non const value to get the index
44894495
auto tensorList = op.getIndices();
44904496
SmallVector<Value> tensorsTorchType;

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,6 +2517,15 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
25172517

25182518
// -----
25192519

2520+
func.func @torch.aten.index.Tensor_hacked_twin.dynamic_size(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[?,4],f32> attributes {torch.assume_strict_symbolic_shapes} {
2521+
%0 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1,4],si64>) -> !torch.list<vtensor>
2522+
// expected-error @+1 {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}}
2523+
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list<vtensor> -> !torch.vtensor<[?,4],f32>
2524+
return %1 : !torch.vtensor<[?,4],f32>
2525+
}
2526+
2527+
// -----
2528+
25202529
// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic(
25212530
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>,
25222531
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {

0 commit comments

Comments
 (0)