1 change: 1 addition & 0 deletions ai_diffusion/api.py
@@ -111,6 +111,7 @@ class ConditioningInput:
     control: list[ControlInput] = field(default_factory=list)
     regions: list[RegionInput] = field(default_factory=list)
     language: str = ""
+    edit_reference: bool = False  # use input image as conditioning reference


 class InpaintMode(Enum):
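
The new `edit_reference` flag simply travels with the rest of the conditioning payload. A minimal sketch of how a caller might set it, using a trimmed stand-in for the real dataclass (the `positive` field here is assumed for illustration, not copied from the source):

```python
from dataclasses import dataclass

@dataclass
class ConditioningInput:  # trimmed stand-in for ai_diffusion.api.ConditioningInput
    positive: str = ""            # assumed field, for illustration only
    language: str = ""
    edit_reference: bool = False  # use input image as conditioning reference

# An edit-style job opts in to using the input image as reference:
cond = ConditioningInput(positive="replace the sky with stars", edit_reference=True)
print(cond)
```
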
2 changes: 1 addition & 1 deletion ai_diffusion/comfy_client.py
@@ -727,7 +727,7 @@ def _find_text_encoder_models(model_list: Sequence[str]):
     kind = ResourceKind.text_encoder
     return {
         resource_id(kind, Arch.all, te): _find_model(model_list, kind, Arch.all, te)
-        for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3"]
+        for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3_4b", "qwen_3_8b"]
     }


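
Splitting `qwen_3` into `qwen_3_4b` and `qwen_3_8b` just adds one more key to the dict comprehension. A self-contained sketch with `resource_id` and `_find_model` stubbed out (the real implementations compose enum-based ids and match known filename patterns):

```python
from typing import Sequence

def resource_id(kind: str, arch: str, name: str) -> str:
    return f"{kind}/{arch}/{name}"  # stub; the real function works on enums

def _find_model(model_list: Sequence[str], name: str) -> str | None:
    return next((m for m in model_list if name in m), None)  # stub lookup

def find_text_encoder_models(model_list: Sequence[str]):
    kind = "text_encoder"
    return {
        resource_id(kind, "all", te): _find_model(model_list, te)
        for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3_4b", "qwen_3_8b"]
    }

print(find_text_encoder_models(["qwen_3_4b_fp8.safetensors", "clip_l.safetensors"]))
```
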
132 changes: 76 additions & 56 deletions ai_diffusion/comfy_workflow.py
@@ -29,6 +29,11 @@ class Output(NamedTuple):
 Input = int | float | bool | str | Output


+class ConditioningOutput(NamedTuple):
+    positive: Output
+    negative: Output
+
+
 class ComfyNode(NamedTuple):
     id: int
     type: str
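
`ConditioningOutput` is a plain named pair, so it unpacks like the `(positive, negative)` tuples it replaces while giving call sites attribute access. A runnable sketch (the two-field `Output` stand-in mirrors the node-reference tuple used in this file):

```python
from typing import NamedTuple

class Output(NamedTuple):  # stand-in for comfy_workflow.Output
    node: int
    slot: int

class ConditioningOutput(NamedTuple):
    positive: Output
    negative: Output

cond = ConditioningOutput(positive=Output(3, 0), negative=Output(4, 0))
pos, neg = cond              # tuple unpacking still works
print(cond.positive == pos)  # attribute access reads better at call sites
```
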
@@ -270,8 +275,7 @@ def _add_image_hashed(self, image: Image):
     def ksampler(
         self,
         model: Output,
-        positive: Output,
-        negative: Output,
+        cond: ConditioningOutput,
         latent_image: Output,
         sampler="dpmpp_2m_sde_gpu",
         scheduler="normal",
@@ -288,8 +292,8 @@ def ksampler(
             sampler_name=sampler,
             scheduler=scheduler,
             model=model,
-            positive=positive,
-            negative=negative,
+            positive=cond.positive,
+            negative=cond.negative,
             latent_image=latent_image,
             steps=steps,
             cfg=cfg,
@@ -299,8 +303,7 @@ def ksampler_advanced(
         self,
         model: Output,
-        positive: Output,
-        negative: Output,
+        cond: ConditioningOutput,
         latent_image: Output,
         sampler="dpmpp_2m_sde_gpu",
         scheduler="normal",
@@ -318,8 +321,8 @@ def ksampler_advanced(
             sampler_name=sampler,
             scheduler=scheduler,
             model=model,
-            positive=positive,
-            negative=negative,
+            positive=cond.positive,
+            negative=cond.negative,
             latent_image=latent_image,
             steps=steps,
             start_at_step=start_at_step,
@@ -332,26 +335,28 @@ def ksampler_advanced(
     def sampler_custom_advanced(
         self,
         model: Output,
-        positive: Output,
-        negative: Output,
+        cond: ConditioningOutput,
         latent_image: Output,
         arch: Arch,
-        sampler="dpmpp_2m_sde_gpu",
+        sampler="euler",
         scheduler="normal",
         steps=20,
         start_at_step=0,
         cfg=7.0,
         seed=-1,
+        extent=Extent(1024, 1024),
     ):
         self.sample_count += steps - start_at_step

         if arch.is_flux_like:
-            positive = self.flux_guidance(positive, cfg if cfg > 1 else 3.5)
+            positive = self.flux_guidance(cond.positive, cfg if cfg > 1 else 3.5)
             guider = self.basic_guider(model, positive)
+        elif cfg == 1.0:
+            guider = self.basic_guider(model, cond.positive)
         else:
-            guider = self.cfg_guider(model, positive, negative, cfg)
+            guider = self.cfg_guider(model, cond, cfg)

-        sigmas = self.scheduler_sigmas(model, scheduler, steps, arch)
+        sigmas = self.scheduler_sigmas(model, scheduler, steps, arch, extent)
         if start_at_step > 0:
             _, sigmas = self.split_sigmas(sigmas, start_at_step)
@@ -366,12 +371,12 @@ def sampler_custom_advanced(
         )[1]

     def scheduler_sigmas(
-        self, model: Output, scheduler="normal", steps=20, model_version=Arch.sdxl
+        self, model: Output, scheduler="normal", steps=20, arch=Arch.sdxl, extent=Extent(1024, 1024)
     ):
         if scheduler in ("align_your_steps", "ays"):
-            assert model_version is Arch.sd15 or model_version.is_sdxl_like
+            assert arch is Arch.sd15 or arch.is_sdxl_like

-            if model_version is Arch.sd15:
+            if arch is Arch.sd15:
                 model_type = "SD1"
             else:
                 model_type = "SDXL"
@@ -410,6 +415,14 @@ def scheduler_sigmas(
                 mu=0.0,
                 beta=0.5,
             )
+        elif scheduler == "flux2":
+            return self.add(
+                "Flux2Scheduler",
+                output_count=1,
+                steps=steps,
+                width=extent.width,
+                height=extent.height,
+            )
         else:
             return self.add(
                 "BasicScheduler",
@@ -431,13 +444,13 @@ def split_sigmas(self, sigmas: Output, step=0):
     def basic_guider(self, model: Output, positive: Output):
         return self.add("BasicGuider", 1, model=model, conditioning=positive)

-    def cfg_guider(self, model: Output, positive: Output, negative: Output, cfg=7.0):
+    def cfg_guider(self, model: Output, cond: ConditioningOutput, cfg=7.0):
         return self.add(
             "CFGGuider",
             output_count=1,
             model=model,
-            positive=positive,
-            negative=negative,
+            positive=cond.positive,
+            negative=cond.negative,
             cfg=cfg,
         )

@@ -592,7 +605,10 @@ def empty_latent_image(self, extent: Extent, arch: Arch, batch_size=1):
         w, h = extent.width, extent.height
         if arch.is_flux_like or arch.is_qwen_like or arch in (Arch.sd3, Arch.chroma, Arch.zimage):
             return self.add("EmptySD3LatentImage", 1, width=w, height=h, batch_size=batch_size)
-        return self.add("EmptyLatentImage", 1, width=w, height=h, batch_size=batch_size)
+        if arch.is_flux2:
+            return self.add("EmptyFlux2LatentImage", 1, width=w, height=h, batch_size=batch_size)
+        else:
+            return self.add("EmptyLatentImage", 1, width=w, height=h, batch_size=batch_size)

     def empty_latent_layers(self, extent: Extent, layer_count: int, batch_size=1):
         w, h = extent.width, extent.height
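
Latent-image node selection is now a three-way split per architecture family. A sketch of the mapping with the `Arch` predicates reduced to plain strings (the groupings are assumed for illustration):

```python
def empty_latent_node(arch: str) -> str:
    # Order follows the diff: SD3-style latents first, then the Flux2 node.
    if arch in ("flux", "qwen", "sd3", "chroma", "zimage"):
        return "EmptySD3LatentImage"
    if arch == "flux2":
        return "EmptyFlux2LatentImage"
    return "EmptyLatentImage"

for a in ("sd15", "flux2", "sd3"):
    print(a, "->", empty_latent_node(a))
```
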
@@ -657,16 +673,17 @@ def conditioning_zero_out(self, conditioning: Output):
         return self.add("ConditioningZeroOut", 1, conditioning=conditioning)

     def instruct_pix_to_pix_conditioning(
-        self, positive: Output, negative: Output, vae: Output, pixels: Output
+        self, cond: ConditioningOutput, vae: Output, pixels: Output
     ):
-        return self.add(
+        pos, neg, model = self.add(
             "InstructPixToPixConditioning",
             3,
-            positive=positive,
-            negative=negative,
+            positive=cond.positive,
+            negative=cond.negative,
             vae=vae,
             pixels=pixels,
         )
+        return ConditioningOutput(pos, neg), model

     def reference_latent(self, conditioning: Output, latent: Output):
         return self.add("ReferenceLatent", 1, conditioning=conditioning, latent=latent)
@@ -712,50 +729,52 @@ def attention_mask(self, model: Output, regions: Output):

     def apply_controlnet(
         self,
-        positive: Output,
-        negative: Output,
+        cond: ConditioningOutput,
         controlnet: Output,
         image: Output,
         vae: Output,
         strength=1.0,
         range: tuple[float, float] = (0.0, 1.0),
     ):
-        return self.add(
-            "ControlNetApplyAdvanced",
-            2,
-            positive=positive,
-            negative=negative,
-            control_net=controlnet,
-            image=image,
-            vae=vae,
-            strength=strength,
-            start_percent=range[0],
-            end_percent=range[1],
+        return ConditioningOutput(
+            *self.add(
+                "ControlNetApplyAdvanced",
+                2,
+                positive=cond.positive,
+                negative=cond.negative,
+                control_net=controlnet,
+                image=image,
+                vae=vae,
+                strength=strength,
+                start_percent=range[0],
+                end_percent=range[1],
+            )
         )

     def apply_controlnet_inpainting(
         self,
-        positive: Output,
-        negative: Output,
+        cond: ConditioningOutput,
         controlnet: Output,
         vae: Output,
         image: Output,
         mask: Output,
         strength=1.0,
         range: tuple[float, float] = (0.0, 1.0),
     ):
-        return self.add(
-            "ControlNetInpaintingAliMamaApply",
-            2,
-            positive=positive,
-            negative=negative,
-            control_net=controlnet,
-            vae=vae,
-            image=image,
-            mask=mask,
-            strength=strength,
-            start_percent=range[0],
-            end_percent=range[1],
+        return ConditioningOutput(
+            *self.add(
+                "ControlNetInpaintingAliMamaApply",
+                2,
+                positive=cond.positive,
+                negative=cond.negative,
+                control_net=controlnet,
+                vae=vae,
+                image=image,
+                mask=mask,
+                strength=strength,
+                start_percent=range[0],
+                end_percent=range[1],
+            )
         )

     def set_controlnet_type(self, controlnet: Output, mode: ControlMode):
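
Because both ControlNet helpers now take and return a `ConditioningOutput`, stacked control layers thread through a single variable instead of parallel `positive`/`negative` arguments. A mocked chaining example (none of these function bodies come from the source):

```python
from typing import NamedTuple

class ConditioningOutput(NamedTuple):
    positive: str
    negative: str

def apply_controlnet(cond: ConditioningOutput, controlnet: str, strength: float = 1.0):
    # Mock: the real method adds a ControlNetApplyAdvanced node and wraps
    # its two outputs back into a ConditioningOutput.
    tag = f"+{controlnet}@{strength}"
    return ConditioningOutput(cond.positive + tag, cond.negative + tag)

cond = ConditioningOutput("pos", "neg")
for net, s in [("depth", 0.8), ("canny", 0.5)]:
    cond = apply_controlnet(cond, net, s)  # one variable threads through
print(cond)
```
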
@@ -894,17 +913,18 @@ def apply_fooocus_inpaint(self, model: Output, patch: Output, latent: Output):
         return self.add("INPAINT_ApplyFooocusInpaint", 1, model=model, patch=patch, latent=latent)

     def vae_encode_inpaint_conditioning(
-        self, vae: Output, image: Output, mask: Output, positive: Output, negative: Output
+        self, vae: Output, image: Output, mask: Output, cond: ConditioningOutput
     ):
-        return self.add(
+        pos, neg, latent_inpaint, latent = self.add(
             "INPAINT_VAEEncodeInpaintConditioning",
             4,
             vae=vae,
             pixels=image,
             mask=mask,
-            positive=positive,
-            negative=negative,
+            positive=cond.positive,
+            negative=cond.negative,
         )
+        return ConditioningOutput(pos, neg), latent_inpaint, latent

     def vae_encode(self, vae: Output, image: Output):
         return self.add("VAEEncode", 1, vae=vae, pixels=image)
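
Taken together, the comfy_workflow.py changes let a sampling pipeline thread one conditioning value end to end. A hypothetical end-to-end shape, with every stage mocked (none of these function bodies come from the source):

```python
from typing import NamedTuple

class ConditioningOutput(NamedTuple):
    positive: str
    negative: str

# Hypothetical pipeline shape after the refactor: every stage accepts and
# returns the bundled pair instead of separate positive/negative arguments.
def encode_prompts() -> ConditioningOutput:
    return ConditioningOutput("pos", "neg")

def apply_control(cond: ConditioningOutput) -> ConditioningOutput:
    return ConditioningOutput(cond.positive + "+cn", cond.negative + "+cn")

def sample(model: str, cond: ConditioningOutput, latent: str) -> str:
    return f"sampled({model}, {cond.positive}|{cond.negative}, {latent})"

cond = apply_control(encode_prompts())
print(sample("unet", cond, "latent"))
```
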
20 changes: 13 additions & 7 deletions ai_diffusion/control.py
@@ -121,12 +121,13 @@ def to_api(self, bounds: Bounds | None = None, time: int | None = None):
         if self.mode.is_lines or self.mode is ControlMode.stencil:
             image.make_opaque(background=Qt.GlobalColor.white)

-        if self._model.arch.is_edit:
-            if image.extent.height > extent.height:
-                w = (image.extent.width * extent.height) // image.extent.height
-                image = Image.scale(image, Extent(w, extent.height))
-        elif self.mode.is_ip_adapter:
-            image = Image.scale(image, self.clip_vision_extent)
+        if self.mode.is_ip_adapter:
+            if self._model.arch.supports_edit:
+                if image.extent.height > extent.height:
+                    w = (image.extent.width * extent.height) // image.extent.height
+                    image = Image.scale(image, Extent(w, extent.height))
+            else:
+                image = Image.scale(image, self.clip_vision_extent)

         strength = self.strength / self.strength_multiplier
         return ControlInput(self.mode, image, strength, (self.start, self.end))
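
For edit-capable architectures the reference image is only capped to the target height, preserving aspect ratio with integer arithmetic. A runnable check of that formula:

```python
def cap_to_height(image_w: int, image_h: int, target_h: int) -> tuple[int, int]:
    # Same integer arithmetic as to_api: only downscale when taller than target.
    if image_h > target_h:
        return (image_w * target_h) // image_h, target_h
    return image_w, image_h

print(cap_to_height(1536, 2048, 1024))  # -> (768, 1024)
print(cap_to_height(800, 600, 1024))    # unchanged -> (800, 600)
```
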
@@ -141,6 +142,7 @@ def _update_is_supported(self):
         is_supported = True
         if client := root.connection.client_if_connected:
             models = client.models.for_arch(self._model.arch)
+
             if self.mode.is_ip_adapter and models.arch in [Arch.illu, Arch.illu_v]:
                 resid = resource_id(ResourceKind.clip_vision, Arch.illu, "ip_adapter")
                 has_clip_vision = client.models.resources.get(resid, None) is not None
@@ -151,7 +153,7 @@ def _update_is_supported(self):
                 self.error_text = _("The server is missing the ClipVision model") + f" {search}"
                 is_supported = False

-            if self.mode.is_ip_adapter and models.arch.is_edit:
+            if self.mode.is_ip_adapter and models.arch.supports_edit:
                 is_supported = True  # Reference images are merged into the conditioning context
             elif self.mode.is_ip_adapter and models.ip_adapter.find(self.mode) is None:
                 search_path = resources.search_path(ResourceKind.ip_adapter, models.arch, self.mode)
@@ -164,6 +166,10 @@ def _update_is_supported(self):
                 if not client.features.ip_adapter:
                     self.error_text = _("IP-Adapter is not supported by this GPU")
                     is_supported = False
+            elif self.mode.is_control_net and models.arch.supports_edit:
+                is_supported = self.mode.can_substitute_instruction(models.arch)
+                if not is_supported:
+                    self.error_text = _("Not supported for") + f" {models.arch.value}"
             elif self.mode.is_control_net:
                 model = models.find_control(self.mode)
                 self.has_range = model == models.model_patch.find(self.mode, True)
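
The new branch gates ControlNet modes on edit models behind `can_substitute_instruction`, with an explicit error message otherwise. A condensed sketch of the two edit-model branches, with booleans replacing the real mode and arch objects:

```python
def reference_supported(is_ip_adapter: bool, is_control_net: bool,
                        arch_supports_edit: bool, can_substitute: bool) -> bool:
    # Condensed from _update_is_supported's new edit-arch handling.
    if is_ip_adapter and arch_supports_edit:
        return True  # reference images merge into the conditioning context
    if is_control_net and arch_supports_edit:
        return can_substitute  # only modes expressible as an edit instruction
    return False  # sketch only; the real code falls through to model lookups

print(reference_supported(True, False, True, False))  # IP-Adapter on edit arch -> True
print(reference_supported(False, True, True, False))  # unsupported ControlNet -> False
```
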