Skip to content

Commit 752cc3a

Browse files
Authored by Lucas Liebenwein
[https://nvbugs/5606166][fix] AutoDeploy: use tuples for cudagraph shape lookup (#8772)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent d2071d7 commit 752cc3a

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def capture_graph(self, *args, **kwargs):
124124
args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static)
125125

126126
# capture graph for truncated inputs
127-
combined_shape = sum((input.shape for input in inputs_truncated), start=())
127+
combined_shape = sum((tuple(input.shape) for input in inputs_truncated), start=())
128128
self.graphs[combined_shape] = self._capture_one_graph(*args, **kwargs)
129129

130130
def forward(self, *args, **kwargs) -> Any:
@@ -142,7 +142,7 @@ def forward(self, *args, **kwargs) -> Any:
142142

143143
# Calculate rounded-up shapes for each input
144144
rounded_shapes = [
145-
(self.round_to_cuda_batch_size(input.shape[0]),) + input.shape[1:]
145+
(self.round_to_cuda_batch_size(input.shape[0]),) + tuple(input.shape[1:])
146146
for input in args_batched
147147
]
148148
combined_shape = sum(rounded_shapes, start=())

tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_cuda_graph_batch_sizes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
class TestCudaGraphBatchSizes:
3636
"""Test class for CUDA graph batch size handling."""
3737

38+
@staticmethod
39+
def _raise_error_for_forward(*args, **kwargs):
40+
raise RuntimeError("forward method should not be called")
41+
3842
@pytest.fixture
3943
def simple_model_and_inputs(self):
4044
"""Create a simple model and inputs for testing."""
@@ -197,7 +201,13 @@ def test_forward_uses_cuda_graph_for_valid_batch_sizes(self, simple_model_and_in
197201
test_input = data["input_tensor"][:batch_size]
198202

199203
with torch.inference_mode():
200-
output = captured_graph.forward(test_input)
204+
# temporarily remove model forward to ensure that the captured graph is used
205+
original_forward = captured_graph.model.forward
206+
captured_graph.model.forward = self._raise_error_for_forward
207+
try:
208+
output = captured_graph.forward(test_input)
209+
finally:
210+
captured_graph.model.forward = original_forward
201211

202212
# Should get valid output
203213
assert output is not None

0 commit comments

Comments (0)