Skip to content

Commit 66bcb93

Browse files
committed
Undoing the empty-tensor changes, keeping the cat different-case handling
1 parent e6fc22b commit 66bcb93

File tree

3 files changed

+0
-53
lines changed

3 files changed

+0
-53
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,19 +249,9 @@ def parse_cat_args(
249249
return input_tensors, dim
250250

251251

252-
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
253-
# empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
254-
inputs, _ = parse_cat_args(node.args, node.kwargs)
255-
for each_input in inputs:
256-
if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
257-
return False
258-
return True
259-
260-
261252
@dynamo_tensorrt_converter(
262253
torch.ops.aten.cat.default,
263254
supports_dynamic_shapes=True,
264-
capability_validator=cat_validator,
265255
)
266256
def aten_ops_cat(
267257
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ def cat(
3030
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3131
trt_inputs = []
3232
for i, each_input in enumerate(input):
33-
if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
34-
logger.warning(
35-
f"Warning: empty tensor in cat input {i}, replacing with zeros"
36-
)
37-
# ITensor with same condition leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
38-
# hence the validator
39-
continue
4033
if not isinstance(each_input, TRTTensor):
4134
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
4235
if cast_dtype:

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -60,42 +60,6 @@ def forward(self, x, y):
6060
inputs = [x, y]
6161
self.run_test(Cat(), inputs)
6262

63-
@parameterized.expand(
64-
[
65-
("pos", 0),
66-
("neg", -3),
67-
]
68-
)
69-
def test_cat_with_empty_tensor(self, _, dim):
70-
# Handle empty tensor in concat
71-
class Cat(nn.Module):
72-
def forward(self, x):
73-
y = torch.empty(0, 2, 3, device="cuda")
74-
return torch.ops.aten.cat.default((x, y), dim)
75-
76-
inputs = [
77-
torch.randn(1, 2, 3, device="cuda"),
78-
]
79-
self.run_test(Cat(), inputs)
80-
81-
@parameterized.expand(
82-
[
83-
("pos", 2),
84-
("neg", -1),
85-
]
86-
)
87-
def test_cat_with_different_dtypes(self, _, dim):
88-
# check dtype promotion path in concat
89-
class Cat(nn.Module):
90-
def forward(self, x, y):
91-
return torch.ops.aten.cat.default((x, y), dim)
92-
93-
inputs = [
94-
torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
95-
torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
96-
]
97-
self.run_test(Cat(), inputs)
98-
9963
@parameterized.expand(
10064
[
10165
("pos", 1),

0 commit comments

Comments (0)