Skip to content

Commit 0f15e7d

Browse files
committed
Addressing cat empty tensor case. Fixes gpt2 data distributed example
1 parent 2e0f502 commit 0f15e7d

File tree

4 files changed

+67
-3
lines changed

4 files changed

+67
-3
lines changed

examples/distributed_inference/data_parallel_stable_diffusion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,5 @@
5353

5454
# Assume there are 2 processes (2 devices)
5555
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
56-
print("before \n")
5756
result = pipe(prompt).images[0]
58-
print("after ")
5957
result.save(f"result_{distributed_state.process_index}.png")

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,17 @@ def aten_ops_native_group_norm(
218218
)
219219

220220

221-
@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
221+
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
222+
# Reject the node (return False) if any cat input is a TRTTensor with a zero-sized dimension, i.e. an empty tensor
223+
for each_input in node.args[0]:
224+
if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
225+
return False
226+
return True
227+
228+
229+
@dynamo_tensorrt_converter(
230+
torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator
231+
)
222232
def aten_ops_cat(
223233
ctx: ConversionContext,
224234
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
set_layer_name,
1515
)
1616

17+
logger = logging.getLogger(__name__)
18+
1719

1820
def unify_and_concat_trt_tensors(
1921
ctx: ConversionContext,

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,60 @@ def forward(self, x, y, z):
2525
inputs,
2626
)
2727

28+
@parameterized.expand(
29+
[
30+
("pos", 0),
31+
("neg", -3),
32+
]
33+
)
34+
def test_cat_with_scalar_inputs(self, _, dim):
35+
# Ensure scalar tensor wrap works
36+
class Cat(nn.Module):
37+
def forward(self, x, y):
38+
# y stands in for a scalar (a constant-filled tensor with x's shape), x is a tensor
39+
return torch.ops.aten.cat.default((x, y), dim)
40+
41+
x = torch.randn(1, 2, 3, device="cuda")
42+
y = torch.ones_like(x) * 5.0 # simulate scalar broadcast
43+
inputs = [x, y]
44+
self.run_test(Cat(), inputs)
45+
46+
@parameterized.expand(
47+
[
48+
("pos", 0),
49+
("neg", -3),
50+
]
51+
)
52+
def test_cat_with_empty_tensor(self, _, dim):
53+
# Handle empty tensor in concat
54+
class Cat(nn.Module):
55+
def forward(self, x):
56+
y = torch.empty(0, 2, 3, device="cuda")
57+
return torch.ops.aten.cat.default((x, y), dim)
58+
59+
inputs = [
60+
torch.randn(1, 2, 3, device="cuda"),
61+
]
62+
self.run_test(Cat(), inputs)
63+
64+
@parameterized.expand(
65+
[
66+
("pos", 2),
67+
("neg", -1),
68+
]
69+
)
70+
def test_cat_with_different_dtypes(self, _, dim):
71+
# check dtype promotion path in concat
72+
class Cat(nn.Module):
73+
def forward(self, x, y):
74+
return torch.ops.aten.cat.default((x, y), dim)
75+
76+
inputs = [
77+
torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
78+
torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
79+
]
80+
self.run_test(Cat(), inputs)
81+
2882
@parameterized.expand(
2983
[
3084
("pos", 1),

0 commit comments

Comments
 (0)