Commit e610087

Shirong Wu authored and Wei Wei committed
Dynamic shape for split (#71)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/71

Enable the split op for dynamic shape.

Reviewed By: frank-wei

Differential Revision: D36291482

fbshipit-source-id: 52058403a237dd9aeb55cb84adb183015b3be152
1 parent 0edba55 commit e610087

File tree: 2 files changed, +82 -9 lines

fx/converters/acc_ops_converters.py
test/converters/acc_op/test_split.py


fx/converters/acc_ops_converters.py

Lines changed: 30 additions & 8 deletions
@@ -2361,17 +2361,17 @@ def acc_ops_slice_tensor(
 
     ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
-
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         if dim == 0:
             raise RuntimeError(
                 f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
             )
         dim = dim - 1
     else:
-        raise RuntimeError(
-            "We don't support slice_tensor with explicit batch dimension yet!"
-        )
+        if dynamic_shape:
+            # Check whether slice target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
 
     start_int = cast(int, kwargs["start"])
     stop_int = cast(int, kwargs["stop"])
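Note: has_dynamic_shape is referenced but not defined in this diff. A minimal sketch consistent with how it is used here, assuming fx2trt's convention that -1 marks a dimension whose extent is only known at runtime:

def has_dynamic_shape(shape) -> bool:
    # True if any dimension's extent is unknown at build time (-1).
    return any(s == -1 for s in shape)

The new assert then refuses to slice along a dimension whose extent the converter cannot see at build time; all other dimensions may stay dynamic.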
@@ -2383,7 +2383,18 @@ def acc_ops_slice_tensor(
     output_shape = list(input_val.shape)
     output_shape[dim] = (stop_int - start_int) // step_int
 
-    layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride)
+    if dynamic_shape > 0:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
     set_layer_name(layer, target, name)
     return layer.get_output(0)

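This is the standard TensorRT pattern for a slice whose output extents are unknown at build time: the static shape argument is left empty, and the layer's input at index 2 (the size input of a slice layer) is wired to a shape tensor that get_shape_with_dynamic_shape computes inside the network. Conceptually, that shape tensor keeps every statically known extent and substitutes the runtime extent wherever the build-time value is -1. A plain-NumPy sketch of the selection, for illustration only (the real helper assembles it from TensorRT shape layers):

import numpy as np

def resolve_output_shape(static_shape, runtime_shape):
    # Keep known extents; take the runtime extent where the static value is -1.
    static = np.asarray(static_shape)
    return np.where(static == -1, np.asarray(runtime_shape), static)

# A (-1, 10, -1) input arriving at runtime as (4, 10, 16), sliced to 5 along dim 1:
print(resolve_output_shape([-1, 5, -1], [4, 10, 16]))  # [ 4  5 16]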
@@ -2584,11 +2595,14 @@ def acc_ops_split(
     )
 
     dim = cast(int, kwargs["dim"])
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         assert dim != 0, "Can't split on batch dim when it's implicit!"
         dim -= 1
     else:
-        raise RuntimeError("We don't support split with explicit batch dimension yet!")
+        if dynamic_shape > 0:
+            # Check whether slice target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
 
     split_size = cast(int, kwargs["split_size"])
     start = [0] * len(input_val.shape)
@@ -2607,7 +2621,15 @@ def acc_ops_split(
         shape = list(input_val.shape)
         shape[dim] = min(split_size, cast(int, max_offset - offset))
         start[dim] = offset
-        layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_shape_{i}"
+            )
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
         offset += split_size
         set_layer_name(layer, target, f"{name}_{i}")
         output.append(layer.get_output(0))
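For reference, the loop's slicing plan: each chunk starts at the running offset and takes min(split_size, remaining) elements along dim, so the last chunk is shorter when the extent is not evenly divisible, matching torch.split semantics. A standalone sketch of the arithmetic:

def split_plan(extent, split_size):
    # (start, length) pairs for each output chunk along the split dim.
    offset, plan = 0, []
    while offset < extent:
        plan.append((offset, min(split_size, extent - offset)))
        offset += split_size
    return plan

print(split_plan(10, 3))  # [(0, 3), (3, 3), (6, 3), (9, 1)]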
@@ -2761,7 +2783,7 @@ def acc_ops_getitem(
         slices = (slices,)
 
     dynamic_shape = get_dynamic_dims(input_val.shape)
-    if dynamic_shape:
+    if len(dynamic_shape) > 0:
         for i, s in zip(input_val.shape, slices):
             assert i > 0 or (
                 s in [slice(None, None, None), slice(0, None, None), Ellipsis]
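get_dynamic_dims is likewise not shown in this diff; the rewritten condition makes the list-ness explicit (len(...) > 0 behaves identically to truthiness for a list, so this hunk is a readability change). A plausible minimal version under the same -1 convention:

def get_dynamic_dims(shape):
    # Indices of dimensions whose extent is unknown at build time.
    return [i for i, s in enumerate(shape) if s == -1]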

test/converters/acc_op/test_split.py

Lines changed: 52 additions & 1 deletion
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_fx2trt import InputTensorSpec, AccTestCase
 from torch.testing._internal.common_utils import run_tests
 
 
@@ -48,6 +48,57 @@ def forward(self, x):
             test_explicit_batch_dim=False,
         )
 
+    @parameterized.expand(
+        [
+            ("split_size", 3, 1),
+            ("sections", [5, 2, 3], 1),
+        ]
+    )
+    def test_split_with_dynamic_shape(self, _, split_size_or_sections, dim):
+        class Split(nn.Module):
+            def forward(self, x):
+                return x.split(split_size_or_sections, dim)[0]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 10, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 10, 10), (5, 10, 15), (10, 10, 20))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Split(),
+            input_specs,
+            expected_ops={
+                acc_ops.split
+                if isinstance(split_size_or_sections, int)
+                else acc_ops.slice_tensor
+            },
+        )
+
+    @parameterized.expand(
+        [
+            ("split_with_size", [2, 3, 5], 1),
+        ]
+    )
+    def test_split_with_size_dynamic_shape(self, _, split_size, dim):
+        class Split(nn.Module):
+            def forward(self, x):
+                return x.split_with_sizes(split_size, dim)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 10, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Split(),
+            input_specs,
+            expected_ops={acc_ops.slice_tensor},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
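For context, the eager-mode behavior these tests pin down. With split_size=3 over an extent of 10, torch.split yields a short final chunk:

import torch

x = torch.randn(4, 10, 12)   # one shape the (-1, 10, -1) spec admits
chunks = x.split(3, dim=1)   # split_size=3 over an extent of 10
print([tuple(c.shape) for c in chunks])
# [(4, 3, 12), (4, 3, 12), (4, 3, 12), (4, 1, 12)]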
