@@ -1,4 +1,6 @@
+import logging
 import warnings
+from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
 
 import numpy
@@ -15,6 +17,10 @@
 from .input_tensor_spec import InputTensorSpec
 from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
 
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
 TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
     Callable[[torch.fx.GraphModule], None]
 ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
@@ -179,7 +185,12 @@ def run(
             warnings.warn("Current platform doesn't support fast native fp16!")
 
         self.input_specs_iter = 0
+        run_module_start_time = datetime.now()
         super().run()
+        _LOGGER.info(
+            f"Run Module elapsed time: {datetime.now() - run_module_start_time}"
+        )
+        build_engine_start_time = datetime.now()
 
         self.builder.max_batch_size = max_batch_size
         builder_config = self.builder.create_builder_config()
@@ -227,6 +238,9 @@ def run(
             if builder_config.get_timing_cache()
             else bytearray()
         )
+        _LOGGER.info(
+            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
+        )
 
         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
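The pattern this diff introduces (capture `datetime.now()` before a stage, log the delta through a module-level logger after it) can be factored into a small reusable helper. Below is a minimal sketch of that idea, not part of this commit: the `log_elapsed` context manager, the `logging.basicConfig` call, and the `time.sleep` stand-ins for the interpreter and engine-build stages are hypothetical, added only to illustrate the timing/logging pattern.

```python
import logging
import time
from contextlib import contextmanager
from datetime import datetime

# Enable INFO-level output so the elapsed-time messages are visible;
# without this, the _LOGGER.info(...) calls in the diff are silent
# under Python's default WARNING log level.
logging.basicConfig(level=logging.INFO)
_LOGGER: logging.Logger = logging.getLogger(__name__)


@contextmanager
def log_elapsed(label: str):
    # Record wall-clock start time, yield to the timed block,
    # then log the elapsed timedelta under the given label.
    start = datetime.now()
    try:
        yield
    finally:
        _LOGGER.info(f"{label} elapsed time: {datetime.now() - start}")


# Usage mirroring the two stages timed in the diff:
with log_elapsed("Run Module"):
    time.sleep(0.1)  # stand-in for super().run() interpreting the fx graph

with log_elapsed("Build TRT engine"):
    time.sleep(0.2)  # stand-in for the TensorRT builder producing the engine
```

The try/finally ensures the elapsed time is logged even if the timed stage raises, which the inline `start = datetime.now()` pattern in the diff does not guarantee.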