Skip to content

Commit 0f2d14e

Browse files
committed
add unload_before and unload_after options, code refactor
1 parent 6bf3b54 commit 0f2d14e

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

tensorrt_convert.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
{".engine"},
2525
)
2626

27+
2728
class TQDMProgressMonitor(trt.IProgressMonitor):
2829
def __init__(self):
2930
trt.IProgressMonitor.__init__(self)
@@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
9394
except KeyboardInterrupt:
9495
# There is no need to propagate this exception to TensorRT. We can simply cancel the build.
9596
return False
96-
97+
9798

9899
class TRT_MODEL_CONVERSION_BASE:
99100
def __init__(self):
100101
self.output_dir = folder_paths.get_output_directory()
101102
self.temp_dir = folder_paths.get_temp_directory()
102103
self.timing_cache_path = os.path.normpath(
103-
os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
104+
os.path.join(
105+
os.path.join(
106+
os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
107+
)
108+
)
104109
)
105110

106111
RETURN_TYPES = ()
@@ -148,26 +153,31 @@ def _convert(
148153
context_max,
149154
num_video_frames,
150155
is_static: bool,
156+
unload_before: bool = True,
157+
unload_after: bool = True,
151158
):
152159
output_onnx = os.path.normpath(
153-
os.path.join(
154-
os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
155-
)
160+
os.path.join(self.temp_dir, str(time.time()), "model.onnx")
156161
)
157162

158-
comfy.model_management.unload_all_models()
163+
if unload_before:
164+
comfy.model_management.unload_all_models()
159165
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
160166
unet = model.model.diffusion_model
161167

162168
context_dim = model.model.model_config.unet_config.get("context_dim", None)
163169
context_len = 77
164170
context_len_min = context_len
165171

166-
if context_dim is None: #SD3
167-
context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
172+
if context_dim is None: # SD3
173+
context_embedder_config = model.model.model_config.unet_config.get(
174+
"context_embedder_config", None
175+
)
168176
if context_embedder_config is not None:
169-
context_dim = context_embedder_config.get("params", {}).get("in_features", None)
170-
context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
177+
context_dim = context_embedder_config.get("params", {}).get(
178+
"in_features", None
179+
)
180+
context_len = 154 # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
171181

172182
if context_dim is not None:
173183
input_names = ["x", "timesteps", "context"]
@@ -179,7 +189,7 @@ def _convert(
179189
"context": {0: "batch", 1: "num_embeds"},
180190
}
181191

182-
transformer_options = model.model_options['transformer_options'].copy()
192+
transformer_options = model.model_options["transformer_options"].copy()
183193
if model.model.model_config.unet_config.get(
184194
"use_temporal_resblock", False
185195
): # SVD
@@ -205,7 +215,13 @@ def forward(self, x, timesteps, context, y):
205215
unet = svd_unet
206216
context_len_min = context_len = 1
207217
else:
218+
208219
class UNET(torch.nn.Module):
220+
def __init__(self, unet, opts):
221+
super().__init__()
222+
self.unet = unet
223+
self.transformer_options = opts
224+
209225
def forward(self, x, timesteps, context, y=None):
210226
return self.unet(
211227
x,
@@ -214,10 +230,8 @@ def forward(self, x, timesteps, context, y=None):
214230
y,
215231
transformer_options=self.transformer_options,
216232
)
217-
_unet = UNET()
218-
_unet.unet = unet
219-
_unet.transformer_options = transformer_options
220-
unet = _unet
233+
234+
unet = UNET(unet, transformer_options)
221235

222236
input_channels = model.model.model_config.unet_config.get("in_channels")
223237

@@ -272,7 +286,8 @@ def forward(self, x, timesteps, context, y=None):
272286
dynamic_axes=dynamic_axes,
273287
)
274288

275-
comfy.model_management.unload_all_models()
289+
if unload_after:
290+
comfy.model_management.unload_all_models()
276291
comfy.model_management.soft_empty_cache()
277292

278293
# TRT conversion starts here
@@ -304,7 +319,9 @@ def forward(self, x, timesteps, context, y=None):
304319
profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)
305320

306321
# Encode shapes to filename
307-
encode = lambda a: ".".join(map(lambda x: str(x), a))
322+
def encode(a):
323+
return ".".join(map(str, a))
324+
308325
prefix_encode += "{}#{}#{}#{};".format(
309326
input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
310327
)
@@ -589,6 +606,8 @@ def INPUT_TYPES(s):
589606
"step": 1,
590607
},
591608
),
609+
"unload_before": ("BOOLEAN", {"default": True}),
610+
"unload_after": ("BOOLEAN", {"default": True}),
592611
},
593612
}
594613

@@ -601,6 +620,8 @@ def convert(
601620
width_opt,
602621
context_opt,
603622
num_video_frames,
623+
unload_before: bool = True,
624+
unload_after: bool = True,
604625
):
605626
return super()._convert(
606627
model,
@@ -619,6 +640,8 @@ def convert(
619640
context_opt,
620641
num_video_frames,
621642
is_static=True,
643+
unload_before=unload_before,
644+
unload_after=unload_after,
622645
)
623646

624647

0 commit comments

Comments
 (0)