Skip to content

Commit 4086fdc

Browse files
author
Wei Wei
committed
[fx2trt] support for new_ones, new_empty, as_strided, einsum (#80)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/80 1. Add support for new_ones and new_empty. new_empty is filled with uninitialized data, so there is no test for it. 2. Add support for as_strided and einsum. 3. Add some print information for interpreter.run and inference time. 4. Add support for einsum. 5. Add some print information in testing to show compile time and run time. 6. Fix a bug in where. 7. acc_tracer needs to recompile after DCE. Reviewed By: yinghai Differential Revision: D36460857 fbshipit-source-id: f74a7df8d5b11c1a9478cb9840a7c47577c3bdc0
1 parent 0d2668a commit 4086fdc

File tree

10 files changed

+360
-11
lines changed

10 files changed

+360
-11
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 142 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,7 +2483,10 @@ def acc_ops_where(
24832483

24842484
if type(x_t) != TRTTensor:
24852485
if x_shape != output_shape:
2486-
x_t.expand(output_shape)
2486+
# special case where 1 element in x_t
2487+
if len(x_t.shape) == 0:
2488+
x_t = x_t.unsqueeze(0)
2489+
x_t = x_t.expand(output_shape)
24872490
x_val = get_trt_tensor(network, x_t, f"{name}_x")
24882491
else:
24892492
x_val = x_t
@@ -2498,7 +2501,10 @@ def acc_ops_where(
24982501

24992502
if type(y_t) != TRTTensor:
25002503
if y_shape != output_shape:
2501-
y_t.expand(output_shape)
2504+
# special case where 1 element in y_t
2505+
if len(y_t.shape) == 0:
2506+
y_t = y_t.unsqueeze(0)
2507+
y_t = y_t.expand(output_shape)
25022508
y_val = get_trt_tensor(network, y_t, f"{name}_y")
25032509
else:
25042510
y_val = y_t
@@ -2912,16 +2918,20 @@ def acc_ops_cat(
29122918
name: str,
29132919
) -> Union[TRTTensor, Sequence[TRTTensor]]:
29142920
tensors = kwargs["tensors"]
2921+
dim = kwargs["dim"]
29152922

29162923
if any(not isinstance(t, TRTTensor) for t in tensors): # type: ignore[union-attr]
29172924
raise RuntimeError(
29182925
f"cat received inputs {tensors} that is not part " "of the TensorRT region!"
29192926
)
2920-
29212927
layer = network.add_concatenation(inputs=tensors)
2922-
layer.axis = cast(int, kwargs["dim"]) - (
2923-
1 if network.has_implicit_batch_dimension else 0
2924-
)
2928+
if dim < 0:
2929+
if network.has_implicit_batch_dimension:
2930+
dim = len(tensors[0].shape) + 1 + dim
2931+
else:
2932+
dim = len(tensors[0].shape) + dim
2933+
2934+
layer.axis = dim - (1 if network.has_implicit_batch_dimension else 0)
29252935
set_layer_name(layer, target, name)
29262936
return layer.get_output(0)
29272937

@@ -3477,3 +3487,129 @@ def acc_ops_interpolate(
34773487

34783488
set_layer_name(layer, target, name)
34793489
return layer.get_output(0)
3490+
3491+
3492+
@tensorrt_converter(acc_ops.new_ones)
def acc_ops_new_ones(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.new_ones`` into a constant tensor filled with ones.

    Args:
        network: Network the constant is added to.
        target: Target of the node being converted (unused here).
        args: Positional arguments of the node (unused here).
        kwargs: Must contain ``input`` and ``size``; may contain ``dtype``
            and ``device``.
        name: Node name; the constant is registered as ``{name}_weight``.

    Returns:
        The result of ``get_trt_tensor`` for a ``torch.ones(size)`` tensor.

    Raises:
        AssertionError: If ``device`` is given and is not ``"cuda"``.
    """
    input_val = kwargs["input"]
    size_val = kwargs["size"]

    dtype_val = kwargs.get("dtype")
    if dtype_val is None:
        # No explicit dtype requested: inherit the input's dtype, translated
        # with torch_dtype_from_trt so that torch.ones() accepts it.
        dtype_val = torch_dtype_from_trt(input_val.dtype)

    device_val = kwargs.get("device")
    assert (
        device_val is None or device_val == "cuda"
    ), f"device is not `cuda` but {device_val}"

    weight = torch.ones(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")
3514+
3515+
3516+
@tensorrt_converter(acc_ops.new_empty)
def acc_ops_new_empty(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.new_empty`` into a constant tensor.

    Args:
        network: Network the constant is added to.
        target: Target of the node being converted (unused here).
        args: Positional arguments of the node (unused here).
        kwargs: Must contain ``input`` and ``size``; may contain ``dtype``
            and ``device``.
        name: Node name; the constant is registered as ``{name}_weight``.

    Returns:
        The result of ``get_trt_tensor`` for a ``torch.zeros(size)`` tensor.

    Raises:
        AssertionError: If ``device`` is given and is not ``"cuda"``.
    """
    input_val = kwargs["input"]
    size_val = kwargs["size"]

    dtype_val = kwargs.get("dtype")
    if dtype_val is None:
        # No explicit dtype requested: inherit the input's dtype, translated
        # with torch_dtype_from_trt so that torch.zeros() accepts it.
        dtype_val = torch_dtype_from_trt(input_val.dtype)

    device_val = kwargs.get("device")
    assert (
        device_val is None or device_val == "cuda"
    ), f"device is not `cuda` but {device_val}"

    # new_empty promises no particular contents, so a zero-filled constant is
    # used as a deterministic stand-in for "uninitialized" memory.
    weight = torch.zeros(size_val, dtype=dtype_val)
    return get_trt_tensor(network, weight, f"{name}_weight")
3538+
3539+
3540+
@tensorrt_converter(acc_ops.einsum)
def acc_ops_einsum(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.einsum`` into a TensorRT einsum layer.

    Args:
        network: Network the layer is added to.
        target: Target of the node being converted; used for layer naming.
        args: Positional arguments of the node (unused here).
        kwargs: Must contain ``operands`` (the einsum inputs) and
            ``equation`` (the einsum equation string).
        name: Node name; used as the prefix for every tensor/layer created.

    Returns:
        Output 0 of the einsum layer.

    Raises:
        AssertionError: If ``equation`` is not a string.
    """
    operands = list(kwargs["operands"])
    equation = kwargs["equation"]
    assert isinstance(equation, str), "equation type is not str"

    has_const_operand = False
    for i, operand in enumerate(operands):
        if isinstance(operand, torch.Tensor):
            # A constant turned into a TRT tensor always reports dtype FLOAT,
            # whatever its stored type, so cast it to float up front. The
            # remaining operands are then brought to float32 below.
            operand = operand.to(torch.float)
            has_const_operand = True
        operands[i] = get_trt_tensor(network, operand, name + f"_input_source{i}")

    if has_const_operand:
        # All operands must share the float32 type of the constant(s).
        for i, operand in enumerate(operands):
            if operand.dtype != trt.float32:
                operands[i] = type_cast(
                    network, target, f"{name}_input_cast{i}", operand, trt.float32
                )

    einsum_layer = network.add_einsum(inputs=operands, equation=equation)
    # Name the layer like the other converters do, so it can be identified
    # in engine logs and profiles.
    set_layer_name(einsum_layer, target, name)
    return einsum_layer.get_output(0)
3568+
3569+
3570+
@tensorrt_converter(acc_ops.as_strided)
def acc_ops_as_strided(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``acc_ops.as_strided`` into flatten + gather + reshape.

    The input is flattened to 1-D, one flat index per output element is
    computed on the host, those elements are gathered, and the result is
    reshaped to ``size``.

    Args:
        network: Network the layers are added to.
        target: Target of the node being converted; used for layer naming.
        args: Positional arguments of the node (unused here).
        kwargs: Must contain ``input``, ``size`` and ``stride``; may contain
            ``storage_offset`` (treated as 0 when absent or ``None``).
        name: Node name; used as the prefix for every tensor/layer created.

    Returns:
        Output 0 of the final shuffle (reshape) layer.

    Raises:
        AssertionError: If ``size`` and ``stride`` have different lengths.
    """
    import itertools

    size = kwargs["size"]
    stride = kwargs["stride"]
    offset = kwargs.get("storage_offset")
    if offset is None:
        offset = 0
    # Validate before any layer is added to the network.
    assert len(size) == len(stride), "size and stride shapes are not the same"

    # Convert the input to a 1-D vector.
    flatten_kwargs = {"input": kwargs["input"], "start_dim": 0, "end_dim": -1}
    flatten_output = acc_ops_flatten(
        network, target, [], flatten_kwargs, name + "_flatten"
    )

    # One flat index per output element, enumerated with the last dimension
    # varying fastest. This table has prod(size) entries and is built in
    # Python, so conversion is slow for large output sizes.
    flat_indices = [
        sum(i * s for i, s in zip(coord, stride))
        for coord in itertools.product(*(range(n) for n in size))
    ]
    indices = torch.tensor(flat_indices, dtype=torch.int) + offset
    indices_tensor = get_trt_tensor(network, indices, name + "_indices_tensor")

    # Use gather to collect the output elements from the flattened input.
    gather_layer = network.add_gather(flatten_output, indices_tensor, axis=0)
    set_layer_name(gather_layer, target, name + "_gather")

    # Reshape the gathered 1-D result to the requested size.
    shuffle_layer = network.add_shuffle(gather_layer.get_output(0))
    set_layer_name(shuffle_layer, target, name + "_shuffle")
    shuffle_layer.reshape_dims = tuple(size)

    return shuffle_layer.get_output(0)

fx/passes/lower_basic_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
356356
inp = node.args[2]
357357

358358
inp_flag = False
359-
if inp.target == operator.getitem:
359+
if type(inp) == torch.fx.node.Node and inp.target == operator.getitem:
360360
new_args = list(copy.deepcopy(inp.args[1]))
361361
for ind, val in enumerate(new_args):
362362
if type(val) == int:

fx/tools/timing_cache_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@ def __init__(self, timing_cache_prefix: str = "", save_timing_cache=False):
1212
if not timing_cache_prefix and tc:
1313
timing_cache_prefix_name = tc
1414

15-
self.timing_cache_prefix_name = timing_cache_prefix
15+
self.timing_cache_prefix_name = timing_cache_prefix_name
1616
self.save_timing_cache = save_timing_cache
1717

1818
def get_file_full_name(self, name: str):
1919
return f"{self.timing_cache_prefix_name}_{name}.npy"
2020

2121
def get_timing_cache_trt(self, timing_cache_file: str) -> bytearray:
2222
timing_cache_file = self.get_file_full_name(timing_cache_file)
23-
with open(timing_cache_file, "rb") as raw_cache:
24-
cache_data = raw_cache.read()
25-
return bytearray(cache_data)
23+
try:
24+
with open(timing_cache_file, "rb") as raw_cache:
25+
cache_data = raw_cache.read()
26+
return bytearray(cache_data)
27+
except Exception:
28+
return None
2629

2730
def update_timing_cache(
2831
self, timing_cache_file: str, serilized_cache: bytearray
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
2+
import torch
3+
import torch.nn as nn
4+
from parameterized import parameterized
5+
from torch.testing._internal.common_fx2trt import AccTestCase
6+
from torch.testing._internal.common_utils import run_tests
7+
8+
9+
class TestConverter(AccTestCase):
    """Converter test for ``acc_ops.as_strided``."""

    # Each case: (case name, input shape, size, stride, storage offset).
    _AS_STRIDED_CASES = [
        ("2d_dim_v1", (5, 5), (2, 3), (1, 2), 0),
        ("2d_dim_v2", (5, 5), (2, 3), (2, 2), 1),
        ("3d_dim_v1", (20, 20), (2, 3, 2), (2, 2, 2), 0),
        # take long time on large dimensions, we do not have better implementation yet
        # ("4d_dim_v1", (200, 200, 200, 200), (9, 9, 3, 2), (2, 2, 2, 3), 0),
        # ("4d_dim_v2", (200, 200, 200, 200), (1, 15, 512, 1), (4096, 256, 1, 1), 0),
    ]

    @parameterized.expand(_AS_STRIDED_CASES)
    def test_as_strided(self, _, x_size, size, stride, offset):
        """Run ``torch.as_strided`` through the converter for one case."""

        class Stride(nn.Module):
            def forward(self, x):
                return torch.as_strided(x, size, stride, offset)

        sample_inputs = [torch.randn(*x_size)]
        self.run_test(
            Stride(),
            sample_inputs,
            expected_ops={acc_ops.as_strided},
            test_implicit_batch_dim=False,
        )
32+
33+
34+
if __name__ == "__main__":
35+
run_tests()

test/converters/acc_op/test_cat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ def forward(self, x, y, z):
1414
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
1515
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
1616

17+
def test_cat_neg(self):
    """Concatenate three tensors along a negative dim (-1)."""

    class Cat(nn.Module):
        def forward(self, x, y, z):
            return torch.cat((x, y, z), -1)

    # The shapes differ only in the last dimension, the one concatenated on.
    shapes = [(1, 2, 3), (1, 2, 3), (1, 2, 2)]
    inputs = [torch.randn(*shape) for shape in shapes]
    self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
24+
1725
def test_cat_with_dynamic_shape(self):
1826
class Cat(nn.Module):
1927
def forward(self, x, y):

test/converters/acc_op/test_einsum.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
2+
import torch
3+
import torch.nn as nn
4+
from parameterized import parameterized
5+
from torch.testing._internal.common_fx2trt import AccTestCase
6+
from torch.testing._internal.common_utils import run_tests
7+
8+
9+
class TestConverter(AccTestCase):
    """Converter test for ``acc_ops.einsum``."""

    # Each case: (case name, equation, first operand shape, second operand shape).
    _EINSUM_CASES = [
        ("2d_dim", "ij,jk->ik", (2, 3), (3, 4)),
        ("2d_dim_ext", "ij,kj->ik", (2, 3), (4, 3)),
        ("3d_dim", "cxd,cyd->cxy", (3, 4, 5), (3, 6, 5)),
        ("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)),
        ("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)),
        # TRT does not support ellipsis or diagonal operations
    ]

    @parameterized.expand(_EINSUM_CASES)
    def test_einsum(self, _, equation, x_size, y_size):
        """Run a two-operand ``torch.einsum`` through the converter."""

        class Einsum(nn.Module):
            def forward(self, x, y):
                return torch.einsum(equation, x, y)

        sample_inputs = [torch.randn(*x_size), torch.randn(*y_size)]
        self.run_test(
            Einsum(),
            sample_inputs,
            expected_ops={acc_ops.einsum},
            test_implicit_batch_dim=False,
        )
32+
33+
34+
if __name__ == "__main__":
35+
run_tests()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
2+
import torch
3+
import torch.nn as nn
4+
from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
5+
from torch.testing._internal.common_utils import run_tests
6+
7+
8+
class TestNewOnesConverter(AccTestCase):
    """Converter tests for ``acc_ops.new_ones``."""

    def _check_new_ones(self, module):
        """Run ``module`` on a (1, 10) random input and expect ``new_ones``."""
        self.run_test(
            module,
            [torch.randn(1, 10)],
            expected_ops={acc_ops.new_ones},
            test_implicit_batch_dim=False,
        )

    def test_newone(self):
        """``new_ones`` with an explicit dtype."""

        class TestModule(nn.Module):
            def forward(self, x):
                return x.new_ones((3, 5), dtype=torch.float16)

        self._check_new_ones(TestModule())

    def test_newone_no_dtype(self):
        """``new_ones`` with no dtype argument."""

        class TestModule(nn.Module):
            def forward(self, x):
                return x.new_ones((3, 5))

        self._check_new_ones(TestModule())

    def test_newone_device(self):
        """``new_ones`` with an explicit ``cuda`` device."""

        class TestModule(nn.Module):
            def forward(self, x):
                return x.new_ones((3, 5), device="cuda")

        self._check_new_ones(TestModule())
47+
48+
49+
if __name__ == "__main__":
50+
run_tests()

test/tracer/test_acc_tracer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,5 +2500,9 @@ def test_all_acc_ops_registered(self):
25002500
acc_ops.isinf,
25012501
acc_ops.any,
25022502
acc_ops.tensor_split,
2503+
acc_ops.new_empty,
2504+
acc_ops.new_ones,
2505+
acc_ops.einsum,
2506+
acc_ops.as_strided,
25032507
},
25042508
)

0 commit comments

Comments
 (0)