Skip to content

Commit 91f91aa

Browse files
feat(mm): prepare kontext latents before loading transformer
If the transformer fills up VRAM, then when we VAE encode kontext latents, we'll need to first offload the transformer (partially, if partial loading is enabled). No need to do this - we can encode kontext latents before loading the transformer to reduce model thrashing.
1 parent ea7868d commit 91f91aa

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,21 @@ def _run_diffusion(
328328
cfg_scale_end_step=self.cfg_scale_end_step,
329329
)
330330

331+
kontext_extension = None
332+
if self.kontext_conditioning:
333+
if not self.controlnet_vae:
334+
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
335+
336+
kontext_extension = KontextExtension(
337+
context=context,
338+
kontext_conditioning=self.kontext_conditioning
339+
if isinstance(self.kontext_conditioning, list)
340+
else [self.kontext_conditioning],
341+
vae_field=self.controlnet_vae,
342+
device=TorchDevice.choose_torch_device(),
343+
dtype=inference_dtype,
344+
)
345+
331346
with ExitStack() as exit_stack:
332347
# Prepare ControlNet extensions.
333348
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -385,21 +400,6 @@ def _run_diffusion(
385400
dtype=inference_dtype,
386401
)
387402

388-
kontext_extension = None
389-
if self.kontext_conditioning:
390-
if not self.controlnet_vae:
391-
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
392-
393-
kontext_extension = KontextExtension(
394-
context=context,
395-
kontext_conditioning=self.kontext_conditioning
396-
if isinstance(self.kontext_conditioning, list)
397-
else [self.kontext_conditioning],
398-
vae_field=self.controlnet_vae,
399-
device=TorchDevice.choose_torch_device(),
400-
dtype=inference_dtype,
401-
)
402-
403403
# Prepare Kontext conditioning if provided
404404
img_cond_seq = None
405405
img_cond_seq_ids = None

0 commit comments

Comments
 (0)