Skip to content

Commit 1a12a5e

Browse files
committed
Empty tensor handling
1 parent dae1ead commit 1a12a5e

File tree

6 files changed

+273
-7
lines changed

6 files changed

+273
-7
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,11 @@ TRTEngine::~TRTEngine() {
249249
trt_engine_profiler.reset();
250250
exec_ctx.reset();
251251
cuda_engine.reset();
252+
for (void* ptr : empty_input_ptrs) {
253+
if (ptr)
254+
cudaFree(ptr);
255+
}
256+
empty_input_ptrs.clear();
252257
rt.reset();
253258
}
254259

core/runtime/TRTEngine.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ struct TRTEngine : torch::CustomClassHolder {
177177
bool use_pre_allocated_outputs = false;
178178
std::vector<at::Tensor> pre_allocated_outputs;
179179

180+
// Empty Input Pointers
181+
std::vector<void*> empty_input_ptrs = {};
182+
180183
// Output Allocator-Related Functionality
181184
bool requires_output_allocator = false; // engine requires output allocator
182185
bool use_output_allocator_outputs = false; // users specify to use output allocator

core/runtime/execute_engine.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ void setup_input_tensors(
129129
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
130130
"Error while setting the tensor address for shape inputs");
131131

132+
void* tensor_addr = nullptr;
132133
if (cudagraphs_enabled) {
133134
// @peri044 I don't know if this makes sense since they are supposed to be GPU buffers
134135
compiled_engine->input_buffers[i] = input_cpu;
@@ -152,15 +153,23 @@ void setup_input_tensors(
152153
if (cudagraphs_enabled) {
153154
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
154155
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
155-
TORCHTRT_CHECK(
156-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
157-
"Error while setting the input tensor address for inputs");
156+
tensor_addr = compiled_engine->input_buffers[i].data_ptr();
158157
} else {
159158
// Otherwise use the formatted buffer directly
160-
TORCHTRT_CHECK(
161-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
162-
"Error while setting the input tensor address for inputs");
159+
tensor_addr = formatted_inputs.back().data_ptr();
163160
}
161+
// Handle empty tensors: TensorRT requires a non-null address even if numel() == 0
162+
size_t nbytes = final_input.numel() * final_input.element_size();
163+
if (nbytes == 0 || tensor_addr == nullptr) {
164+
void* dummy = nullptr;
165+
cudaMalloc(&dummy, 1); // allocate 1 byte GPU buffer to satisfy TRT and get a non-null address
166+
tensor_addr = dummy;
167+
compiled_engine->empty_input_ptrs.push_back(dummy); // track to free later
168+
}
169+
170+
TORCHTRT_CHECK(
171+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), tensor_addr),
172+
"Failed to bind tensor address for " << name);
164173
}
165174
}
166175
}

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,57 @@ def aten_ops_native_group_norm(
217217
)
218218

219219

220+
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
    """
    Capability validator for ``torch.cat`` that rejects rank-mismatched empty inputs.

    PyTorch lets a 1D empty tensor such as ``torch.tensor([])`` (shape ``(0,)``) be
    concatenated with higher-rank tensors, whereas TensorRT requires every
    concatenation input to have the same rank. This validator detects that one
    case so the node can be left to PyTorch.

    Valid:   ``cat([(3, 4), (0, 4)], dim=0)`` - same rank, properly shaped empty tensor
    Invalid: ``cat([(3, 4), (0,)], dim=0)``   - 1D empty tensor with rank mismatch

    Args:
        node: The ``cat`` node; ``node.args[0]`` is read as the sequence of inputs.
        settings: Unused here; accepted so the signature matches other validators.

    Returns:
        ``False`` only when the input ranks differ and at least one input has shape
        ``(0,)``. ``True`` in every other case, including when any non-TRTTensor
        input has no ``tensor_meta`` (nothing to validate against).

    NOTE(review): this function only takes effect if the ``aten.cat`` converter
    registration passes it as its capability validator - confirm that wiring.
    """
    inputs = node.args[0]

    # A single-input cat cannot have a rank mismatch.
    if len(inputs) < 2:
        return True

    # Collect the shape of every input.
    input_shapes = []
    for inp in inputs:
        if isinstance(inp, TRTTensor):
            # TRTTensor exposes its shape directly.
            input_shapes.append(inp.shape)
            continue

        # Otherwise read the shape from the node's recorded tensor metadata.
        meta = getattr(inp, "meta", {}).get("tensor_meta")
        if meta is None:
            # Can't validate without metadata, so allow it.
            return True
        input_shapes.append(tuple(meta.shape))

    # Matching ranks are handled by both PyTorch and TensorRT.
    ranks = {len(shape) for shape in input_shapes}
    if len(ranks) == 1:
        return True

    # Ranks differ: reject only when a 1D empty tensor (0,) is among the inputs.
    for i, shape in enumerate(input_shapes):
        if len(shape) == 1 and shape[0] == 0:
            _LOGGER.debug(
                f"Concatenation rejected by TRT: torch.tensor([]) or 1D empty tensor at position {i}. "
                "PyTorch allows this but TensorRT requires all inputs to have the same rank. "
                "Use torch.empty((0, ...)) with explicit dimensions matching other inputs instead. "
                "Falling back to PyTorch."
            )
            return False

    return True
269+
270+
220271
@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
221272
def aten_ops_cat(
222273
ctx: ConversionContext,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,26 @@ def setup_input_tensors(
392392
self.context.set_input_shape(
393393
input_name, tuple(contiguous_inputs[i].shape)
394394
)
395+
tensor_to_bind = contiguous_inputs[i]
396+
if tensor_to_bind.numel() == 0:
397+
# TRT needs a valid memory address even for an empty input, so bind a 1-element dummy buffer instead
398+
dummy = torch.empty(
399+
1,
400+
dtype=tensor_to_bind.dtype,
401+
device=torch.cuda.current_device(),
402+
)
403+
tensor_to_bind = dummy
404+
if not hasattr(self, "_empty_input_buffers"):
405+
self._empty_input_buffers = []
406+
self._empty_input_buffers.append(dummy)
395407
if cudagraphs_enabled:
396408
self._input_buffers[i].copy_(contiguous_inputs[i])
397409
self.context.set_tensor_address(
398410
input_name, self._input_buffers[i].data_ptr()
399411
)
400412
else:
401413
self.context.set_tensor_address(
402-
input_name, contiguous_inputs[i].data_ptr()
414+
input_name, tensor_to_bind.data_ptr()
403415
)
404416

405417
def create_output_tensors(self) -> List[torch.Tensor]:
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
import torch_tensorrt as torchtrt
5+
from parameterized import parameterized
6+
from torch.testing._internal.common_utils import TestCase, run_tests
7+
8+
DECIMALS_OF_AGREEMENT = 5 # for output comparison
9+
10+
11+
# We provide a non-null address to TRT
12+
class ConcatEmptyModel(nn.Module):
    """Module whose forward pass concatenates its two inputs along ``dim``."""

    def __init__(self, dim=0):
        super(ConcatEmptyModel, self).__init__()
        # Axis handed to torch.cat in forward().
        self.dim = dim

    def forward(self, x, y):
        operands = [x, y]
        return torch.cat(operands, dim=self.dim)
19+
20+
21+
# TRT will handle
22+
class ConcatEmptyModelEmptyConstant(nn.Module):
    """Module that concatenates its input with an empty ``(0, 4)`` float tensor built in forward()."""

    def __init__(self, dim=0):
        super(ConcatEmptyModelEmptyConstant, self).__init__()
        # Axis handed to torch.cat in forward().
        self.dim = dim

    def forward(self, x):
        # The empty operand is created inside the module rather than passed in.
        empty_rows = torch.empty((0, 4), dtype=torch.float).cuda()
        operands = [x, empty_rows]
        return torch.cat(operands, dim=self.dim)
30+
31+
32+
# makes use of validator
33+
class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module):
    """Module that concatenates its input with ``torch.tensor([])``, a 1D empty tensor."""

    def __init__(self, dim=0):
        super(ConcatEmptyModelEmptyConstantMisMatchDim, self).__init__()
        # Axis handed to torch.cat in forward().
        self.dim = dim

    def forward(self, x):
        # torch.tensor([]) has shape (0,), so its rank need not match x.
        empty_1d = torch.tensor([], device="cuda")
        operands = [x, empty_1d]
        return torch.cat(operands, dim=self.dim)
41+
42+
43+
class TestConcatEmptyTensor(TestCase):
    """Compares torch.cat with an empty operand between eager PyTorch and a Torch-TensorRT compiled model."""

    def _compile_and_compare(self, model, inputs, use_python_runtime, mismatch_msg):
        """Compile ``model`` with the dynamo frontend, run both versions, and compare outputs."""
        # NOTE(review): each model here contains a single cat op - confirm that
        # min_block_size=5 still allows it to be lowered to TensorRT.
        compiled_model = torchtrt.compile(
            model,
            "dynamo",
            inputs,
            min_block_size=5,
            use_python_runtime=use_python_runtime,
        )

        # Reference (eager) result first, then the compiled result.
        expected = model(*inputs)
        actual = compiled_model(*inputs)

        self.assertEqual(expected.shape, actual.shape)
        max_abs_diff = float(torch.max(torch.abs(expected - actual)))
        self.assertAlmostEqual(
            max_abs_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg=mismatch_msg,
        )

    @parameterized.expand(
        [
            ("python_runtime_model_one_empty_0", True, ConcatEmptyModel, "two_inputs", (0,)),
            ("cpp_runtime_model_one_empty_0", False, ConcatEmptyModel, "two_inputs", (0,)),
            ("python_runtime_model_one_empty_0_4", True, ConcatEmptyModel, "two_inputs", (0, 4)),
            ("cpp_runtime_model_one_empty_0_4", False, ConcatEmptyModel, "two_inputs", (0, 4)),
            (
                "python_runtime_model_two_empty_0_4",
                True,
                ConcatEmptyModelEmptyConstant,
                "one_input",
                (0, 4),
            ),
            (
                "cpp_runtime_model_two_empty_0_4",
                False,
                ConcatEmptyModelEmptyConstant,
                "one_input",
                (0, 4),
            ),
            (
                "python_runtime_model_three_empty_0",
                True,
                ConcatEmptyModelEmptyConstantMisMatchDim,
                "one_input",
                (0,),
            ),
            (
                "cpp_runtime_model_three_empty_0",
                False,
                ConcatEmptyModelEmptyConstantMisMatchDim,
                "one_input",
                (0,),
            ),
        ]
    )
    def test_concat_empty_with_nonempty(
        self, _, use_python_runtime, model_class, input_type, empty_shape
    ):
        """
        Concatenate an empty tensor with a (3, 4) tensor along dim 0 and check that
        the compiled model agrees with eager execution.
        """
        model = model_class(dim=0).eval().cuda()

        # Build both candidate inputs; "one_input" models only receive the non-empty one.
        empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()
        non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()
        inputs = (
            [empty_input, non_empty_input]
            if input_type == "two_inputs"
            else [non_empty_input]
        )

        self._compile_and_compare(
            model,
            inputs,
            use_python_runtime,
            "Concat with empty tensor output mismatch",
        )

    @parameterized.expand(
        [
            ("python_runtime_empty_0", True, (0,)),
            ("cpp_runtime_empty_0", False, (0,)),
            ("python_runtime_empty_0_4", True, (0, 4)),
            ("cpp_runtime_empty_0_4", False, (0, 4)),
        ]
    )
    def test_concat_nonempty_with_empty(self, _, use_python_runtime, empty_shape):
        """
        Same check with the operands swapped: the non-empty tensor comes first.
        """
        model = ConcatEmptyModel(dim=0).eval().cuda()

        non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda()
        empty_input = torch.empty(empty_shape, dtype=torch.float).cuda()

        self._compile_and_compare(
            model,
            [non_empty_input, empty_input],
            use_python_runtime,
            "Concat with empty tensor (opposite order) output mismatch",
        )
183+
184+
185+
if __name__ == "__main__":
186+
run_tests()

0 commit comments

Comments
 (0)