-
Notifications
You must be signed in to change notification settings - Fork 370
Tentatively eliminate graph break overhead #3741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
d0ae590
56a8949
5fb0beb
0046f66
7259443
a537d9f
a9a27b1
d862b68
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
|
||
import logging | ||
from contextlib import nullcontext | ||
from tempfile import tempdir | ||
from typing import Any, Dict, List, Optional, Sequence, Tuple | ||
|
||
import tensorrt as trt | ||
|
@@ -174,6 +173,8 @@ def __init__( | |
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None | ||
self._caller_stream: Optional[torch.cuda.Stream] = None | ||
self._engine_stream: Optional[torch.cuda.Stream] = None | ||
self.output_tensors: Optional[List[torch.Tensor]] = None | ||
self.sync_stream = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just inherit stream from PyTorch / input tensors |
||
|
||
# TODO: Make the below a Dictionary {shape: cudagraph} | ||
self.shape_key: Optional[str] = None | ||
|
@@ -218,9 +219,18 @@ def __init__( | |
self.requires_output_allocator = requires_output_allocator | ||
self.output_allocator: Optional[DynamicOutputAllocator] = None | ||
self.use_output_allocator_outputs = False | ||
|
||
self.device = torch.cuda.current_device() | ||
self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() | ||
self.requires_unique_output = False | ||
cehongwang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
if self.serialized_engine is not None and not self.settings.lazy_engine_init: | ||
self.setup_engine() | ||
self.is_shape_inference_io = [ | ||
self.engine.is_shape_inference_io(input_name) | ||
for input_name in self.input_names | ||
] | ||
|
||
def set_requires_unique_output(self, requires_unique_output: bool) -> None: | ||
self.requires_unique_output = requires_unique_output | ||
cehongwang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
def get_streamable_device_memory_budget(self) -> Any: | ||
return self.engine.streamable_weights_size | ||
|
@@ -263,6 +273,15 @@ def setup_engine(self) -> None: | |
assert ( | ||
self.target_platform == Platform.current_platform() | ||
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" | ||
# Stream handling: if the caller stream is the pytorch default stream, create a new engine stream | ||
# otherwise, use the caller stream and disable stream synchronization | ||
self._caller_stream = torch.cuda.current_stream() | ||
if self._caller_stream == torch.cuda.default_stream(): | ||
self._engine_stream = torch.cuda.Stream() | ||
self.sync_stream = True | ||
else: | ||
self._engine_stream = self._caller_stream | ||
self.sync_stream = False | ||
|
||
self.initialized = True | ||
runtime = trt.Runtime(TRT_LOGGER) | ||
|
@@ -286,10 +305,14 @@ def setup_engine(self) -> None: | |
for output_name in self.output_names | ||
] | ||
self.output_shapes = [ | ||
self.engine.get_tensor_shape(output_name) | ||
tuple(self.context.get_tensor_shape(output_name)) | ||
for output_name in self.output_names | ||
] | ||
|
||
self.shape_key = "".join( | ||
str(tuple(t)).replace(" ", "") for t in self.input_shapes | ||
) | ||
|
||
if self.requires_output_allocator: | ||
self.create_output_allocator() | ||
|
||
|
@@ -355,6 +378,7 @@ def setup_input_tensors( | |
contiguous_inputs: List[torch.Tensor], | ||
cudagraphs_enabled: bool, | ||
need_cudagraphs_record: bool, | ||
shape_changed: bool = True, | ||
) -> None: | ||
for i, input_name in enumerate(self.input_names): | ||
if not contiguous_inputs[i].is_cuda: | ||
|
@@ -370,9 +394,9 @@ def setup_input_tensors( | |
+ contiguous_inputs[i + 1 :] | ||
) | ||
|
||
assert ( | ||
contiguous_inputs[i].dtype == self.input_dtypes[i] | ||
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." | ||
# assert ( | ||
# contiguous_inputs[i].dtype == self.input_dtypes[i] | ||
# ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." | ||
cehongwang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
if need_cudagraphs_record: | ||
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs | ||
|
@@ -381,16 +405,17 @@ def setup_input_tensors( | |
|
||
# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers | ||
# as per TensorRT requirements | ||
if self.engine.is_shape_inference_io(input_name): | ||
if self.is_shape_inference_io[i]: | ||
|
||
# Shape tensor inputs are casted to int64 explicitly | ||
# Currently Torch CPU pointers are not working; numpy pointers are used instead | ||
# to refer to underlying memory | ||
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() | ||
self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data) | ||
else: | ||
self.context.set_input_shape( | ||
input_name, tuple(contiguous_inputs[i].shape) | ||
) | ||
if shape_changed: | ||
self.context.set_input_shape( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we safely assume execution context holds shape between inference calls? |
||
input_name, tuple(contiguous_inputs[i].shape) | ||
) | ||
if cudagraphs_enabled: | ||
self._input_buffers[i].copy_(contiguous_inputs[i]) | ||
self.context.set_tensor_address( | ||
|
@@ -409,7 +434,7 @@ def create_output_tensors(self) -> List[torch.Tensor]: | |
output = torch.empty( | ||
size=self.output_shapes[o], | ||
dtype=self.output_dtypes[o], | ||
device=torch.cuda.current_device(), | ||
device=self.device, | ||
) | ||
outputs.append(output) | ||
return outputs | ||
|
@@ -458,7 +483,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." | ||
|
||
self.setup_input_tensors( | ||
contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record | ||
contiguous_inputs, | ||
self.cudagraphs_enabled, | ||
need_cudagraphs_record, | ||
shape_changed | ||
or self.output_tensors is None, # First time execution | ||
) | ||
|
||
if shape_changed: | ||
|
@@ -480,15 +509,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
if can_use_pre_allocated_outputs: | ||
outputs = self.pre_allocated_outputs | ||
else: | ||
self.output_shapes = [ | ||
tuple(self.context.get_tensor_shape(output_name)) | ||
for output_name in self.output_names | ||
] | ||
if shape_changed: | ||
self.output_shapes = [ | ||
tuple(self.context.get_tensor_shape(output_name)) | ||
for output_name in self.output_names | ||
] | ||
if DYNAMIC_DIM in self.output_shapes: | ||
raise ValueError( | ||
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." | ||
) | ||
outputs = self.create_output_tensors() | ||
if ( | ||
self.output_tensors is None | ||
or self.requires_unique_output | ||
or shape_changed | ||
): | ||
self.output_tensors = self.create_output_tensors() | ||
outputs = self.output_tensors | ||
|
||
for o, output_name in enumerate(self.output_names): | ||
if need_cudagraphs_record: | ||
|
@@ -510,44 +546,39 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
if self.profiling_enabled | ||
else nullcontext() | ||
): | ||
self._caller_stream = torch.cuda.current_stream() | ||
if ( | ||
self._engine_stream == torch.cuda.default_stream() | ||
or self._engine_stream is None | ||
): | ||
self._engine_stream = torch.cuda.Stream() | ||
|
||
self._engine_stream.wait_stream(self._caller_stream) | ||
if self.sync_stream: | ||
self._engine_stream.wait_stream(self._caller_stream) | ||
cehongwang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
with torch.cuda.stream(self._engine_stream): | ||
if self.cudagraphs_enabled: | ||
if need_cudagraphs_record: | ||
self.cudagraph = torch.cuda.CUDAGraph() | ||
if self.cudagraphs_enabled: | ||
if need_cudagraphs_record: | ||
self.cudagraph = torch.cuda.CUDAGraph() | ||
|
||
if self.profiling_enabled: | ||
self.cudagraph.enable_debug_mode() | ||
if self.profiling_enabled: | ||
self.cudagraph.enable_debug_mode() | ||
|
||
with torch.cuda.graph( | ||
self.cudagraph, stream=self._engine_stream | ||
): | ||
self.context.execute_async_v3( | ||
self._engine_stream.cuda_stream | ||
) | ||
with torch.cuda.graph( | ||
self.cudagraph, stream=self._engine_stream | ||
): | ||
self.context.execute_async_v3( | ||
self._engine_stream.cuda_stream | ||
) | ||
|
||
if self.profiling_enabled: | ||
import tempfile | ||
if self.profiling_enabled: | ||
import tempfile | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
self.cudagraph.debug_dump( | ||
f"{tempdir}/{self.name}_cudagraph.dot" | ||
) | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
self.cudagraph.debug_dump( | ||
f"{tmpdir}/{self.name}_cudagraph.dot" | ||
) | ||
|
||
self.cudagraph.replay() # type: ignore | ||
self.cudagraph.replay() # type: ignore | ||
|
||
else: | ||
self.context.execute_async_v3(self._engine_stream.cuda_stream) | ||
else: | ||
self.context.execute_async_v3(self._engine_stream.cuda_stream) | ||
|
||
self._caller_stream.wait_stream(self._engine_stream) | ||
if self.sync_stream: | ||
self._caller_stream.wait_stream(self._engine_stream) | ||
|
||
if self.use_pre_allocated_outputs: | ||
self.pre_allocated_outputs = self.create_output_tensors() | ||
|
@@ -646,8 +677,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: | |
|
||
return outputs | ||
|
||
self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() | ||
|
||
# Run forward function | ||
contiguous_inputs: List[torch.Tensor] = [ | ||
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) | ||
|
@@ -752,13 +781,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: | |
# Representation of input shapes to a given model | ||
# Shapes are concatenated as so: | ||
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) | ||
tensor_inputs = [] | ||
for t in inputs: | ||
if not isinstance(t, torch.Tensor): | ||
return True | ||
tensor_inputs.append(t) | ||
if not all(isinstance(t, torch.Tensor) for t in inputs): | ||
return True | ||
|
||
new_shape_key = "".join( | ||
str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs | ||
str(tuple(t.shape)).replace(" ", "") | ||
for t in inputs | ||
if isinstance(t, torch.Tensor) | ||
) | ||
|
||
# If the new shape key differs from the existing one, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this going to work with serialization in C++?
Also make the name clearer like
trt_module.module_is_output_operator
ortrt_module.requires_unowned_output_tensor