@@ -591,13 +591,11 @@ def _save_weight_mapping(self) -> None:
591
591
torch .cuda .empty_cache ()
592
592
593
593
@needs_refit # type: ignore[misc]
594
- def _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) -> None :
594
+ def _insert_engine_to_cache (self , hash_val : str , engine : bytes ) -> None :
595
+ serialized_engine = engine .serialize ()
595
596
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
596
597
# if not self.compilation_settings.strip_engine_weights:
597
598
# # set EXCLUDE_WEIGHTS flag to strip weights
598
- # runtime = trt.Runtime(TRT_LOGGER)
599
- # engine = runtime.deserialize_cuda_engine(serialized_engine)
600
-
601
599
# serialization_config = engine.create_serialization_config()
602
600
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
603
601
# serialized_engine = engine.serialize_with_config(
@@ -731,10 +729,6 @@ def run(
731
729
if interpreter_result is not None : # hit the cache
732
730
return interpreter_result # type: ignore[no-any-return]
733
731
734
- import psutil
735
-
736
- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
737
- # breakpoint()
738
732
self ._construct_trt_network_def ()
739
733
740
734
if not self .compilation_settings .immutable_weights :
@@ -753,14 +747,11 @@ def run(
753
747
self ._create_timing_cache (
754
748
builder_config , self .compilation_settings .timing_cache_path
755
749
)
756
- import psutil
757
-
758
- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
759
- # breakpoint()
760
750
761
751
cuda_engine = self .builder .build_engine_with_config (
762
752
self .ctx .net , builder_config
763
753
)
754
+ assert cuda_engine
764
755
765
756
_LOGGER .info (
766
757
f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
@@ -772,17 +763,13 @@ def run(
772
763
)
773
764
774
765
# Engine caching only for refittable engines
775
- # if (
776
- # not self.compilation_settings.immutable_weights
777
- # and self.compilation_settings.cache_built_engines
778
- # and self.engine_cache is not None
779
- # ):
780
- # self._insert_engine_to_cache(hash_val, serialized_engine)
781
-
782
- print ("After build_engine_with_config" )
783
- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
784
- # breakpoint()
785
- assert cuda_engine
766
+ if (
767
+ not self .compilation_settings .immutable_weights
768
+ and self .compilation_settings .cache_built_engines
769
+ and self .engine_cache is not None
770
+ ):
771
+ self ._insert_engine_to_cache (hash_val , cuda_engine )
772
+
786
773
if self .compilation_settings .use_python_runtime :
787
774
return TRTInterpreterResult (
788
775
cuda_engine ,
@@ -792,16 +779,13 @@ def run(
792
779
self .ctx .requires_output_allocator ,
793
780
)
794
781
else :
795
- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
796
- # breakpoint()
797
782
serialized_engine = cuda_engine .serialize ()
798
783
_LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
799
784
800
785
with io .BytesIO () as engine_bytes :
801
786
engine_bytes .write (serialized_engine )
802
787
engine_str = engine_bytes .getvalue ()
803
- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
804
- # breakpoint()
788
+
805
789
return TRTInterpreterResult (
806
790
engine_str ,
807
791
self ._input_names ,
0 commit comments