Skip to content

Commit 5ce9346

Browse files
committed
Undoing the empty-tensor changes, keeping the cat different-dtype case handling
1 parent 7f5c8aa commit 5ce9346

File tree

2 files changed

+0
-46
lines changed

2 files changed

+0
-46
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,19 +250,9 @@ def parse_cat_args(
250250
return input_tensors, dim
251251

252252

253-
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
254-
# empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
255-
inputs, _ = parse_cat_args(node.args, node.kwargs)
256-
for each_input in inputs:
257-
if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
258-
return False
259-
return True
260-
261-
262253
@dynamo_tensorrt_converter(
263254
torch.ops.aten.cat.default,
264255
supports_dynamic_shapes=True,
265-
capability_validator=cat_validator,
266256
)
267257
def aten_ops_cat(
268258
ctx: ConversionContext,

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -60,42 +60,6 @@ def forward(self, x, y):
6060
inputs = [x, y]
6161
self.run_test(Cat(), inputs)
6262

63-
@parameterized.expand(
64-
[
65-
("pos", 0),
66-
("neg", -3),
67-
]
68-
)
69-
def test_cat_with_empty_tensor(self, _, dim):
70-
# Handle empty tensor in concat
71-
class Cat(nn.Module):
72-
def forward(self, x):
73-
y = torch.empty(0, 2, 3, device="cuda")
74-
return torch.ops.aten.cat.default((x, y), dim)
75-
76-
inputs = [
77-
torch.randn(1, 2, 3, device="cuda"),
78-
]
79-
self.run_test(Cat(), inputs)
80-
81-
@parameterized.expand(
82-
[
83-
("pos", 2),
84-
("neg", -1),
85-
]
86-
)
87-
def test_cat_with_different_dtypes(self, _, dim):
88-
# check dtype promotion path in concat
89-
class Cat(nn.Module):
90-
def forward(self, x, y):
91-
return torch.ops.aten.cat.default((x, y), dim)
92-
93-
inputs = [
94-
torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
95-
torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
96-
]
97-
self.run_test(Cat(), inputs)
98-
9963
@parameterized.expand(
10064
[
10165
("pos", 1),

0 commit comments

Comments (0)