@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
65
65
66
66
67
67
class TRTInterpreterResult (NamedTuple ):
68
- serialized_engine : bytes
68
+ engine : trt . ICudaEngine | bytes
69
69
input_names : Sequence [str ]
70
70
output_names : Sequence [str ]
71
71
weight_name_map : Optional [dict [Any , Any ]]
@@ -731,6 +731,10 @@ def run(
731
731
if interpreter_result is not None : # hit the cache
732
732
return interpreter_result # type: ignore[no-any-return]
733
733
734
+ import psutil
735
+
736
+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
737
+ # breakpoint()
734
738
self ._construct_trt_network_def ()
735
739
736
740
if not self .compilation_settings .immutable_weights :
@@ -749,41 +753,62 @@ def run(
749
753
self ._create_timing_cache (
750
754
builder_config , self .compilation_settings .timing_cache_path
751
755
)
752
- serialized_engine = self .builder .build_serialized_network (
756
+ import psutil
757
+
758
+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
759
+ # breakpoint()
760
+
761
+ cuda_engine = self .builder .build_engine_with_config (
753
762
self .ctx .net , builder_config
754
763
)
755
- assert serialized_engine
756
764
757
765
_LOGGER .info (
758
766
f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
759
767
)
760
- _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
761
-
762
768
self .ctx .clear_cpu_weights_reference_holder ()
763
769
764
770
self ._save_timing_cache (
765
771
builder_config , self .compilation_settings .timing_cache_path
766
772
)
767
773
768
774
# Engine caching only for refittable engines
769
- if (
770
- not self .compilation_settings .immutable_weights
771
- and self .compilation_settings .cache_built_engines
772
- and self .engine_cache is not None
773
- ):
774
- self ._insert_engine_to_cache (hash_val , serialized_engine )
775
-
776
- with io .BytesIO () as engine_bytes :
777
- engine_bytes .write (serialized_engine )
778
- engine_str = engine_bytes .getvalue ()
779
-
780
- return TRTInterpreterResult (
781
- engine_str ,
782
- self ._input_names ,
783
- self ._output_names ,
784
- self .weight_name_map ,
785
- self .ctx .requires_output_allocator ,
786
- )
775
+ # if (
776
+ # not self.compilation_settings.immutable_weights
777
+ # and self.compilation_settings.cache_built_engines
778
+ # and self.engine_cache is not None
779
+ # ):
780
+ # self._insert_engine_to_cache(hash_val, serialized_engine)
781
+
782
+ print ("After build_engine_with_config" )
783
+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
784
+ # breakpoint()
785
+ assert cuda_engine
786
+ if self .compilation_settings .use_python_runtime :
787
+ return TRTInterpreterResult (
788
+ cuda_engine ,
789
+ self ._input_names ,
790
+ self ._output_names ,
791
+ self .weight_name_map ,
792
+ self .ctx .requires_output_allocator ,
793
+ )
794
+ else :
795
+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
796
+ # breakpoint()
797
+ serialized_engine = cuda_engine .serialize ()
798
+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
799
+
800
+ with io .BytesIO () as engine_bytes :
801
+ engine_bytes .write (serialized_engine )
802
+ engine_str = engine_bytes .getvalue ()
803
+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
804
+ # breakpoint()
805
+ return TRTInterpreterResult (
806
+ engine_str ,
807
+ self ._input_names ,
808
+ self ._output_names ,
809
+ self .weight_name_map ,
810
+ self .ctx .requires_output_allocator ,
811
+ )
787
812
788
813
def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
789
814
self ._cur_node_name = get_node_name (n )
0 commit comments