Skip to content

Commit 61092ba

Browse files
authored
feat: Use a global timing cache and add a save option (#2898)
1 parent a7c50b0 commit 61092ba

File tree

4 files changed

+47
-20
lines changed

4 files changed

+47
-20
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def compile(
7878
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
7979
dryrun: bool = _defaults.DRYRUN,
8080
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
81+
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
8182
**kwargs: Any,
8283
) -> torch.fx.GraphModule:
8384
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -137,6 +138,7 @@ def compile(
137138
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the number of graphs run in TensorRT.
138139
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
139140
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
141+
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
140142
**kwargs: Any,
141143
Returns:
142144
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -220,6 +222,7 @@ def compile(
220222
"dla_global_dram_size": dla_global_dram_size,
221223
"dryrun": dryrun,
222224
"hardware_compatible": hardware_compatible,
225+
"timing_cache_path": timing_cache_path,
223226
}
224227

225228
settings = CompilationSettings(**compilation_options)
@@ -477,6 +480,7 @@ def convert_module_to_trt_engine(
477480
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
478481
calibrator: object = None,
479482
allow_shape_tensors: bool = False,
483+
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
480484
**kwargs: Any,
481485
) -> bytes:
482486
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -532,7 +536,7 @@ def convert_module_to_trt_engine(
532536
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
533537
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
534538
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
535-
539+
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
536540
Returns:
537541
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
538542
"""
@@ -585,6 +589,7 @@ def convert_module_to_trt_engine(
585589
"dla_sram_size": dla_sram_size,
586590
"dla_local_dram_size": dla_local_dram_size,
587591
"dla_global_dram_size": dla_global_dram_size,
592+
"timing_cache_path": timing_cache_path,
588593
}
589594

590595
exported_program = pre_export_lowering(exported_program, torch_inputs)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import tempfile
3+
14
import torch
25
from torch_tensorrt._Device import Device
36
from torch_tensorrt._enums import EngineCapability, dtype
@@ -28,6 +31,7 @@
2831
DRYRUN = False
2932
HARDWARE_COMPATIBLE = False
3033
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
34+
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
3135

3236

3337
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
REFIT,
2525
REQUIRE_FULL_COMPILATION,
2626
SPARSE_WEIGHTS,
27+
TIMING_CACHE_PATH,
2728
TRUNCATE_DOUBLE,
2829
USE_FAST_PARTITIONER,
2930
USE_PYTHON_RUNTIME,
@@ -71,6 +72,7 @@ class CompilationSettings:
7172
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
7273
output to a file if a string path is specified
7374
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
75+
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
7476
"""
7577

7678
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -101,3 +103,4 @@ class CompilationSettings:
101103
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
102104
dryrun: Union[bool, str] = DRYRUN
103105
hardware_compatible: bool = HARDWARE_COMPATIBLE
106+
timing_cache_path: str = TIMING_CACHE_PATH

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
import warnings
34
from datetime import datetime
45
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
@@ -44,7 +45,6 @@ class TRTInterpreterResult(NamedTuple):
4445
engine: Any
4546
input_names: Sequence[str]
4647
output_names: Sequence[str]
47-
serialized_cache: bytearray
4848

4949

5050
class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@@ -276,30 +276,43 @@ def _populate_trt_builder_config(
276276
def _create_timing_cache(
277277
self,
278278
builder_config: trt.IBuilderConfig,
279-
existing_cache: Optional[trt.ITimingCache] = None,
280-
) -> trt.ITimingCache:
281-
cache = None
282-
if existing_cache:
283-
cache_file = np.array(existing_cache)
284-
cache = builder_config.create_timing_cache(cache_file.tobytes())
285-
else:
286-
cache = builder_config.create_timing_cache(b"")
279+
timing_cache_path: str = "",
280+
) -> None:
281+
"""
282+
Create a timing cache to enable faster build time for TRT engines.
283+
By default the timing_cache_path="/tmp/timing_cache.bin"
284+
"""
285+
buffer = b""
286+
if os.path.isfile(timing_cache_path):
287+
# Load from existing cache
288+
with open(timing_cache_path, mode="rb") as timing_cache_file:
289+
buffer = timing_cache_file.read()
290+
cache = builder_config.create_timing_cache(buffer)
287291
builder_config.set_timing_cache(cache, False)
288-
return cache
292+
293+
def _save_timing_cache(
294+
self,
295+
builder_config: trt.IBuilderConfig,
296+
timing_cache_path: str,
297+
) -> None:
298+
"""
299+
This is called after a TensorRT engine is built. Save the timing cache
300+
"""
301+
timing_cache = builder_config.get_timing_cache()
302+
with open(timing_cache_path, "wb") as timing_cache_file:
303+
timing_cache_file.write(memoryview(timing_cache.serialize()))
289304

290305
def run(
291306
self,
292307
strict_type_constraints: bool = False,
293308
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
294-
existing_cache: Optional[trt.ITimingCache] = None,
295309
tactic_sources: Optional[int] = None,
296310
) -> TRTInterpreterResult:
297311
"""
298312
Build TensorRT engine with some configs.
299313
Args:
300314
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
301315
algorithm_selector: set up algorithm selection for certain layer
302-
existing_cache: enable timing cache for TensorRT
303316
Return:
304317
TRTInterpreterResult
305318
"""
@@ -316,25 +329,27 @@ def run(
316329
builder_config = self._populate_trt_builder_config(
317330
strict_type_constraints, algorithm_selector, tactic_sources
318331
)
319-
timing_cache = self._create_timing_cache(builder_config, existing_cache)
332+
333+
self._create_timing_cache(
334+
builder_config, self.compilation_settings.timing_cache_path
335+
)
320336

321337
serialized_engine = self.builder.build_serialized_network(
322338
self.ctx.net, builder_config
323339
)
324340
assert serialized_engine
325341

326-
serialized_cache = (
327-
bytearray(timing_cache.serialize())
328-
if builder_config.get_timing_cache()
329-
else bytearray()
330-
)
331342
_LOGGER.info(
332343
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
333344
)
334345
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
335346

347+
self._save_timing_cache(
348+
builder_config, self.compilation_settings.timing_cache_path
349+
)
350+
336351
return TRTInterpreterResult(
337-
serialized_engine, self._input_names, self._output_names, serialized_cache
352+
serialized_engine, self._input_names, self._output_names
338353
)
339354

340355
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:

0 commit comments

Comments
 (0)