Commit a7205e4

Merge branch 'main' into copilot/add-unload-model-option
2 parents: ca14c5c + 65efc3d

File tree

18 files changed: +24573 -5354 lines changed

docs/RELEASE.md

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ The publish jobs will not run if any of the previous jobs fail.
 
 They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
 
-Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
+Both jobs require a @lstein or @blessedcoolant to approve them from the workflow's **Summary** tab.
 
 - Click the **Review deployments** button
 - Select the environment (either `testpypi` or `pypi` - typically you select both)
@@ -101,7 +101,7 @@ Both jobs require a @hipsterusername or @psychedelicious to approve them from th
 
 Check the [python infrastructure status page] for incidents.
 
-If there are no incidents, contact @hipsterusername or @lstein, who have owner access to GH and PyPI, to see if access has expired or something like that.
+If there are no incidents, contact @lstein or @blessedcoolant, who have owner access to GH and PyPI, to see if access has expired or something like that.
 
 #### `publish-testpypi` Job
 
invokeai/app/invocations/fields.py

Lines changed: 5 additions & 0 deletions
@@ -333,6 +333,11 @@ class ZImageConditioningField(BaseModel):
     """A Z-Image conditioning tensor primitive value"""
 
     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )
 
 
 class ConditioningField(BaseModel):

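For orientation, a minimal sketch of how the new optional mask field might be used when building a conditioning field. The names "cond_abc" and "mask_xyz" are hypothetical; only the fields shown in the hunk above are assumed.

from invokeai.app.invocations.fields import TensorField, ZImageConditioningField

# Global prompt: no mask, so the conditioning applies to the whole image (mask defaults to None).
global_cond = ZImageConditioningField(conditioning_name="cond_abc")

# Regional prompt: the referenced mask tensor is True inside the region, False outside.
regional_cond = ZImageConditioningField(
    conditioning_name="cond_abc",
    mask=TensorField(tensor_name="mask_xyz"),
)
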
invokeai/app/invocations/z_image_denoise.py

Lines changed: 93 additions & 33 deletions
@@ -32,23 +32,29 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_controlnet_extension import (
     ZImageControlNetExtension,
     z_image_forward_with_control,
 )
+from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
 
 
 @invocation(
     "z_image_denoise",
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class ZImageDenoiseInvocation(BaseInvocation):
-    """Run the denoising process with a Z-Image model."""
+    """Run the denoising process with a Z-Image model.
+
+    Supports regional prompting by connecting multiple conditioning inputs with masks.
+    """
 
     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
     transformer: TransformerField = InputField(
         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
     )
-    positive_conditioning: ZImageConditioningField = InputField(
+    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_conditioning: Optional[ZImageConditioningField] = InputField(
+    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
     def _load_text_conditioning(
         self,
         context: InvocationContext,
-        conditioning_name: str,
+        cond_field: ZImageConditioningField | list[ZImageConditioningField],
+        img_height: int,
+        img_width: int,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> torch.Tensor:
-        """Load Z-Image text conditioning."""
-        cond_data = context.conditioning.load(conditioning_name)
-        if len(cond_data.conditionings) != 1:
-            raise ValueError(
-                f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = cond_data.conditionings[0]
-        if not isinstance(z_image_conditioning, ZImageConditioningInfo):
-            raise TypeError(
-                f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
-        return z_image_conditioning.prompt_embeds
+    ) -> list[ZImageTextConditioning]:
+        """Load Z-Image text conditioning with optional regional masks.
+
+        Args:
+            context: The invocation context.
+            cond_field: Single conditioning field or list of fields.
+            img_height: Height of the image token grid (H // patch_size).
+            img_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype.
+            device: Target device.
+
+        Returns:
+            List of ZImageTextConditioning objects with embeddings and masks.
+        """
+        # Normalize to a list
+        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+        text_conditionings: list[ZImageTextConditioning] = []
+        for cond in cond_list:
+            # Load the text embeddings
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            z_image_conditioning = cond_data.conditionings[0]
+            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+            prompt_embeds = z_image_conditioning.prompt_embeds
+
+            # Load the mask, if provided
+            mask: torch.Tensor | None = None
+            if cond.mask is not None:
+                mask = context.tensors.load(cond.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, img_height, img_width, dtype, device
+                )
+
+            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+        return text_conditionings
 
     def _get_noise(
         self,
@@ -221,14 +252,33 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         transformer_info = context.models.load(self.transformer.transformer)
 
-        # Load positive conditioning
-        pos_prompt_embeds = self._load_text_conditioning(
+        # Calculate image token grid dimensions
+        patch_size = 2  # Z-Image uses patch_size=2
+        latent_height = self.height // LATENT_SCALE_FACTOR
+        latent_width = self.width // LATENT_SCALE_FACTOR
+        img_token_height = latent_height // patch_size
+        img_token_width = latent_width // patch_size
+        img_seq_len = img_token_height * img_token_width
+
+        # Load positive conditioning with regional masks
+        pos_text_conditionings = self._load_text_conditioning(
             context=context,
-            conditioning_name=self.positive_conditioning.conditioning_name,
+            cond_field=self.positive_conditioning,
+            img_height=img_token_height,
+            img_width=img_token_width,
             dtype=inference_dtype,
             device=device,
         )
 
+        # Create regional prompting extension
+        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+            text_conditionings=pos_text_conditionings,
+            img_seq_len=img_seq_len,
+        )
+
+        # Get the concatenated prompt embeddings for the transformer
+        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
         # Load negative conditioning if provided and guidance_scale != 1.0
         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
@@ -238,21 +288,22 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
         )
         if do_classifier_free_guidance:
-            if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
-            neg_prompt_embeds = self._load_text_conditioning(
+            assert self.negative_conditioning is not None
+            # Load all negative conditionings and concatenate embeddings
+            # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+            neg_text_conditionings = self._load_text_conditioning(
                 context=context,
-                conditioning_name=self.negative_conditioning.conditioning_name,
+                cond_field=self.negative_conditioning,
+                img_height=img_token_height,
+                img_width=img_token_width,
                 dtype=inference_dtype,
                 device=device,
             )
-
-        # Calculate image sequence length for timestep shifting
-        patch_size = 2  # Z-Image uses patch_size=2
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
+            # Concatenate all negative embeddings
+            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
 
         # Calculate shift based on image sequence length
-        mu = self._calculate_shift(image_seq_len)
+        mu = self._calculate_shift(img_seq_len)
 
         # Generate sigma schedule with time shift
         sigmas = self._get_sigmas(mu, self.steps)
@@ -443,6 +494,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             )
         )
 
+        # Apply regional prompting patch if we have regional masks
+        exit_stack.enter_context(
+            patch_transformer_for_regional_prompting(
+                transformer=transformer,
+                regional_attn_mask=regional_extension.regional_attn_mask,
+                img_seq_len=img_seq_len,
+            )
+        )
+
         # Denoising loop
         for step_idx in tqdm(range(total_steps)):
             sigma_curr = sigmas[step_idx]

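The regional_attn_mask passed to the transformer patch is built by ZImageRegionalPromptingExtension, whose internals are not part of this diff. As a rough, self-contained sketch of the general technique (an assumption about the approach, not the extension's actual code): each region's text tokens are allowed to attend to, and be attended by, only the image tokens its mask covers, over the concatenated text-plus-image sequence.

import torch

def build_regional_attn_mask(
    region_masks: list[torch.Tensor],  # each (img_seq_len,), bool, True where the region applies
    text_lens: list[int],              # number of text tokens per region, in the same order
) -> torch.Tensor:
    # Illustrative only; shapes and conventions are assumptions.
    img_seq_len = region_masks[0].shape[0]
    txt_seq_len = sum(text_lens)
    total = txt_seq_len + img_seq_len
    mask = torch.zeros(total, total, dtype=torch.bool)

    offset = 0
    for region_mask, n_txt in zip(region_masks, text_lens):
        txt_slice = slice(offset, offset + n_txt)
        mask[txt_slice, txt_slice] = True  # text tokens see their own region's text
        img_idx = torch.nonzero(region_mask, as_tuple=True)[0] + txt_seq_len
        mask[txt_slice, img_idx] = True    # region text <-> region image tokens
        mask[img_idx, txt_slice] = True
        offset += n_txt

    mask[txt_seq_len:, txt_seq_len:] = True  # image tokens always see each other
    return mask
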
invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 21 additions & 5 deletions
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 
 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,25 +34,34 @@
     title="Prompt - Z-Image",
     tags=["prompt", "conditioning", "z-image"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class ZImageTextEncoderInvocation(BaseInvocation):
-    """Encodes and preps a prompt for a Z-Image image."""
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """
 
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     qwen3_encoder: Qwen3EncoderField = InputField(
         title="Qwen3 Encoder",
         description=FieldDescriptions.qwen3_encoder,
         input=Input.Connection,
     )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
-        return ZImageConditioningOutput.build(conditioning_name)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
 
     def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         """Encode prompt using Qwen3 text encoder.

invokeai/app/services/model_install/model_install_default.py

Lines changed: 54 additions & 3 deletions
@@ -1,8 +1,10 @@
 """Model installation class."""
 
+import gc
 import locale
 import os
 import re
+import sys
 import threading
 import time
 from copy import deepcopy
@@ -187,6 +189,22 @@ def register_path(
         config.source_type = ModelSourceType.Path
         return self._register(model_path, config)
 
+    # TODO: Replace this with a proper fix for underlying problem of Windows holding open
+    # the file when it needs to be moved.
+    @staticmethod
+    def _move_with_retries(src: Path, dst: Path, attempts: int = 5, delay: float = 0.5) -> None:
+        """Workaround for Windows file-handle issues when moving files."""
+        for tries_left in range(attempts, 0, -1):
+            try:
+                move(src, dst)
+                return
+            except PermissionError:
+                gc.collect()
+                if tries_left == 1:
+                    raise
+                time.sleep(delay)
+                delay *= 2  # Exponential backoff
+
     def install_path(
         self,
         model_path: Union[Path, str],
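
For reference, a standalone sketch of the retry schedule implied by the defaults above (attempts=5, delay=0.5); it only reproduces the timing and does not touch the filesystem:

attempts, delay = 5, 0.5
waits = []
for tries_left in range(attempts, 0, -1):
    if tries_left == 1:
        break  # the real helper re-raises PermissionError at this point
    waits.append(delay)
    delay *= 2

print(waits)  # [0.5, 1.0, 2.0, 4.0] -> roughly 7.5 s of retrying before giving up
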
@@ -205,7 +223,7 @@ def install_path(
         dest_dir.mkdir(parents=True)
         dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir
         if model_path.is_file():
-            move(model_path, dest_path)
+            self._move_with_retries(model_path, dest_path)  # Windows workaround TODO: fix root cause
         elif model_path.is_dir():
             # Move the contents of the directory, not the directory itself
             for item in model_path.iterdir():
@@ -500,6 +518,39 @@ def _start_installer_thread(self) -> None:
         self._install_thread.start()
         self._running = True
 
+    @staticmethod
+    def _safe_rmtree(path: Path, logger: Any) -> None:
+        """Remove a directory tree with retry logic for Windows file locking issues.
+
+        On Windows, memory-mapped files may not be immediately released even after
+        the file handle is closed. This function retries the removal with garbage
+        collection to help release any lingering references.
+        """
+        max_retries = 3
+        retry_delay = 0.5  # seconds
+
+        for attempt in range(max_retries):
+            try:
+                # Force garbage collection to release any lingering file references
+                gc.collect()
+                rmtree(path)
+                return
+            except PermissionError as e:
+                if attempt < max_retries - 1 and sys.platform == "win32":
+                    logger.warning(
+                        f"Failed to remove {path} (attempt {attempt + 1}/{max_retries}): {e}. "
+                        f"Retrying in {retry_delay}s..."
+                    )
+                    time.sleep(retry_delay)
+                    retry_delay *= 2  # Exponential backoff
+                else:
+                    logger.error(f"Failed to remove temporary directory {path}: {e}")
+                    # On final failure, don't raise - the temp dir will be cleaned up on next startup
+                    return
+            except Exception as e:
+                logger.error(f"Unexpected error removing {path}: {e}")
+                return
+
     def _install_next_item(self) -> None:
         self._logger.debug(f"Installer thread {threading.get_ident()} starting")
         while True:
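
The failure mode _safe_rmtree works around can be reproduced in isolation. A hypothetical, OS-dependent sketch follows; the PermissionError only occurs on Windows, where a live memory mapping keeps the on-disk file locked:

import mmap
import os
import shutil
import tempfile

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "weights.bin")
with open(path, "wb") as f:
    f.write(b"\0" * 4096)

f = open(path, "rb")
m = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)  # model weights are often memory-mapped
f.close()  # the mapping still keeps the file open

try:
    shutil.rmtree(tmpdir)  # raises PermissionError on Windows while the mapping is alive
except PermissionError:
    m.close()  # dropping the last reference (what gc.collect() helps with) releases the lock
    shutil.rmtree(tmpdir)
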
@@ -529,7 +580,7 @@ def _install_next_item(self) -> None:
             finally:
                 # if this is an install of a remote file, then clean up the temporary directory
                 if job._install_tmpdir is not None:
-                    rmtree(job._install_tmpdir)
+                    self._safe_rmtree(job._install_tmpdir, self._logger)
                 self._install_completed_event.set()
                 self._install_queue.task_done()
         self._logger.info(f"Installer thread {threading.get_ident()} exiting")
@@ -574,7 +625,7 @@ def _remove_dangling_install_dirs(self) -> None:
         path = self._app_config.models_path
         for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
            self._logger.info(f"Removing dangling temporary directory {tmpdir}")
-            rmtree(tmpdir)
+            self._safe_rmtree(tmpdir, self._logger)
 
     def _scan_for_missing_models(self) -> list[AnyModelConfig]:
         """Scan the models directory for missing models and return a list of them."""
