     {".engine"},
 )

+
 class TQDMProgressMonitor(trt.IProgressMonitor):
     def __init__(self):
         trt.IProgressMonitor.__init__(self)
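An aside for context on the class in this hunk: a minimal sketch of the trt.IProgressMonitor interface that TQDMProgressMonitor implements. The three callbacks (phase_start, step_complete, phase_finish) are the TensorRT Python API's; the class name and bodies here are illustrative only.

    import tensorrt as trt

    class MinimalProgressMonitor(trt.IProgressMonitor):
        def __init__(self):
            trt.IProgressMonitor.__init__(self)

        def phase_start(self, phase_name, parent_phase, num_steps):
            # Called when the builder enters a (possibly nested) build phase.
            pass

        def step_complete(self, phase_name, step):
            # Returning False asks TensorRT to cancel the build; the hunk
            # below uses this to turn KeyboardInterrupt into a clean cancel.
            return True

        def phase_finish(self, phase_name):
            pass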
@@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
         except KeyboardInterrupt:
             # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
             return False
-
+

 class TRT_MODEL_CONVERSION_BASE:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
         self.temp_dir = folder_paths.get_temp_directory()
         self.timing_cache_path = os.path.normpath(
-            os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
+            os.path.join(
+                os.path.join(
+                    os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
+                )
+            )
         )

     RETURN_TYPES = ()
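The timing_cache_path set up above typically feeds a TensorRT timing cache, which persists kernel-tuning results across engine builds. A hedged sketch of the usual load-and-attach pattern: create_timing_cache and set_timing_cache are real IBuilderConfig methods, but the helper name and error handling are assumptions, not code from this file.

    import os
    import tensorrt as trt

    def attach_timing_cache(config: trt.IBuilderConfig, path: str) -> trt.ITimingCache:
        # Reuse a serialized cache if one exists; otherwise start empty.
        buf = b""
        if os.path.exists(path):
            with open(path, "rb") as f:
                buf = f.read()
        cache = config.create_timing_cache(buf)
        config.set_timing_cache(cache, ignore_mismatch=False)
        return cache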
@@ -148,26 +153,31 @@ def _convert(
         context_max,
         num_video_frames,
         is_static: bool,
+        unload_before: bool = True,
+        unload_after: bool = True,
     ):
         output_onnx = os.path.normpath(
-            os.path.join(
-                os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
-            )
+            os.path.join(self.temp_dir, str(time.time()), "model.onnx")
         )

-        comfy.model_management.unload_all_models()
+        if unload_before:
+            comfy.model_management.unload_all_models()
         comfy.model_management.load_models_gpu([model], force_patch_weights=True)
         unet = model.model.diffusion_model

         context_dim = model.model.model_config.unet_config.get("context_dim", None)
         context_len = 77
         context_len_min = context_len

-        if context_dim is None: #SD3
-            context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
+        if context_dim is None:  # SD3
+            context_embedder_config = model.model.model_config.unet_config.get(
+                "context_embedder_config", None
+            )
             if context_embedder_config is not None:
-                context_dim = context_embedder_config.get("params", {}).get("in_features", None)
-                context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+                context_dim = context_embedder_config.get("params", {}).get(
+                    "in_features", None
+                )
+                context_len = 154  # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77

         if context_dim is not None:
             input_names = ["x", "timesteps", "context"]
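To make the SD3 branch above concrete, this is roughly the shape of unet_config it expects; the dictionary and the in_features value here are illustrative, not taken from a real model config.

    unet_config = {
        # SD3-style configs have no top-level "context_dim" key, so the
        # value is recovered from the context embedder's input width.
        "context_embedder_config": {
            "params": {"in_features": 4096},  # illustrative value
        },
    }
    cfg = unet_config.get("context_embedder_config", None)
    context_dim = cfg.get("params", {}).get("in_features", None) if cfg else None
    # context_dim == 4096; context_len becomes 154 while context_len_min stays 77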
@@ -179,7 +189,7 @@ def _convert(
179189 "context" : {0 : "batch" , 1 : "num_embeds" },
180190 }
181191
182- transformer_options = model .model_options [' transformer_options' ].copy ()
192+ transformer_options = model .model_options [" transformer_options" ].copy ()
183193 if model .model .model_config .unet_config .get (
184194 "use_temporal_resblock" , False
185195 ): # SVD
@@ -205,7 +215,13 @@ def forward(self, x, timesteps, context, y):
                 unet = svd_unet
                 context_len_min = context_len = 1
             else:
+
                 class UNET(torch.nn.Module):
+                    def __init__(self, unet, opts):
+                        super().__init__()
+                        self.unet = unet
+                        self.transformer_options = opts
+
                     def forward(self, x, timesteps, context, y=None):
                         return self.unet(
                             x,
@@ -214,10 +230,8 @@ def forward(self, x, timesteps, context, y=None):
                             y,
                             transformer_options=self.transformer_options,
                         )
-                _unet = UNET()
-                _unet.unet = unet
-                _unet.transformer_options = transformer_options
-                unet = _unet
+
+                unet = UNET(unet, transformer_options)

         input_channels = model.model.model_config.unet_config.get("in_channels")

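The wrapped UNET above is what ultimately goes through torch.onnx.export. A hedged sketch of that call, reusing the input_names and dynamic_axes built earlier in _convert; the example tensors and the output name are assumptions, not copied from this file.

    torch.onnx.export(
        unet,
        (x, timesteps, context, y),   # example tensors matching input_names
        output_onnx,
        input_names=["x", "timesteps", "context", "y"],
        output_names=["h"],           # assumed output name
        dynamic_axes=dynamic_axes,    # e.g. {"x": {0: "batch", 2: "height", 3: "width"}, ...}
    )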
@@ -272,7 +286,8 @@ def forward(self, x, timesteps, context, y=None):
             dynamic_axes=dynamic_axes,
         )

-        comfy.model_management.unload_all_models()
+        if unload_after:
+            comfy.model_management.unload_all_models()
         comfy.model_management.soft_empty_cache()

         # TRT conversion starts here
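The "# TRT conversion starts here" section follows the standard ONNX-to-engine flow. A compressed sketch using the usual TensorRT Python API; flag spelling varies across TensorRT versions, and the shapes below are placeholders rather than values from this file.

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(output_onnx, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(parser.get_error(0))

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # One (min, opt, max) triple per dynamic input, e.g. for "x":
    profile.set_shape("x", (1, 4, 64, 64), (1, 4, 128, 128), (4, 4, 128, 128))
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)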
@@ -304,7 +319,9 @@ def forward(self, x, timesteps, context, y=None):
             profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

             # Encode shapes to filename
-            encode = lambda a: ".".join(map(lambda x: str(x), a))
+            def encode(a):
+                return ".".join(map(str, a))
+
             prefix_encode += "{}#{}#{}#{};".format(
                 input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
             )
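A quick worked example of the encode helper introduced above, as it feeds the engine filename; the shapes are illustrative.

    def encode(a):
        return ".".join(map(str, a))

    # Produces "x#1.4.64.64#1.4.128.128#4.4.128.128;"
    entry = "{}#{}#{}#{};".format(
        "x", encode((1, 4, 64, 64)), encode((1, 4, 128, 128)), encode((4, 4, 128, 128))
    )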
@@ -589,6 +606,8 @@ def INPUT_TYPES(s):
589606 "step" : 1 ,
590607 },
591608 ),
609+ "unload_before" : ("BOOLEAN" , {"default" : True }),
610+ "unload_after" : ("BOOLEAN" , {"default" : True }),
592611 },
593612 }
594613
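For readers unfamiliar with ComfyUI's node schema: BOOLEAN entries like those added above render as toggles in the node UI. A minimal sketch of the convention, with the class name illustrative and the surrounding keys elided:

    class ExampleNode:
        @classmethod
        def INPUT_TYPES(s):
            return {
                "required": {
                    "unload_before": ("BOOLEAN", {"default": True}),
                    "unload_after": ("BOOLEAN", {"default": True}),
                }
            }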
@@ -601,6 +620,8 @@ def convert(
         width_opt,
         context_opt,
         num_video_frames,
+        unload_before: bool = True,
+        unload_after: bool = True,
     ):
         return super()._convert(
             model,
@@ -619,6 +640,8 @@ def convert(
             context_opt,
             num_video_frames,
             is_static=True,
+            unload_before=unload_before,
+            unload_after=unload_after,
         )

