Skip to content

Commit 99f04fb

Browse files
committed
Support Flux 2 edit mode with reference images
1 parent aee5d90 commit 99f04fb

File tree

11 files changed

+100
-63
lines changed

11 files changed

+100
-63
lines changed

ai_diffusion/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class ConditioningInput:
111111
control: list[ControlInput] = field(default_factory=list)
112112
regions: list[RegionInput] = field(default_factory=list)
113113
language: str = ""
114+
edit_reference: bool = False # use input image as conditioning reference
114115

115116

116117
class InpaintMode(Enum):

ai_diffusion/control.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def _update_is_supported(self):
141141
is_supported = True
142142
if client := root.connection.client_if_connected:
143143
models = client.models.for_arch(self._model.arch)
144+
144145
if self.mode.is_ip_adapter and models.arch in [Arch.illu, Arch.illu_v]:
145146
resid = resource_id(ResourceKind.clip_vision, Arch.illu, "ip_adapter")
146147
has_clip_vision = client.models.resources.get(resid, None) is not None
@@ -151,7 +152,7 @@ def _update_is_supported(self):
151152
self.error_text = _("The server is missing the ClipVision model") + f" {search}"
152153
is_supported = False
153154

154-
if self.mode.is_ip_adapter and models.arch.is_edit:
155+
if self.mode.is_ip_adapter and models.arch.supports_edit:
155156
is_supported = True # Reference images are merged into the conditioning context
156157
elif self.mode.is_ip_adapter and models.ip_adapter.find(self.mode) is None:
157158
search_path = resources.search_path(ResourceKind.ip_adapter, models.arch, self.mode)

ai_diffusion/model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,11 @@ def _generate(self, queue_mode: QueueMode):
193193

194194
def _prepare_workflow(self, dryrun=False):
195195
arch = self.arch
196-
is_edit = arch.is_edit
197196
workflow_kind = WorkflowKind.generate
198197
strength = self.strength
199198
if arch is Arch.qwen_l:
200199
strength = 1.0
201-
if strength < 1.0 or is_edit:
200+
if strength < 1.0 or self.is_editing:
202201
workflow_kind = WorkflowKind.refine
203202
client = self._connection.client
204203
image = None
@@ -409,7 +408,7 @@ def generate_live(self):
409408
def _prepare_live_workflow(self):
410409
strength = self.live.strength
411410
workflow_kind = WorkflowKind.generate
412-
if strength < 1.0 or self.arch.is_edit:
411+
if strength < 1.0 or self.is_editing:
413412
workflow_kind = WorkflowKind.refine
414413
client = self._connection.client
415414
min_mask_size = 512 if self.arch is Arch.sd15 else 800
@@ -845,11 +844,11 @@ def set_style(self, style: Style):
845844
self._style_connection = style.changed.connect(self._handle_style_changed)
846845
self.style_changed.emit(style)
847846
self.modified.emit(self, "style")
848-
self.edit_mode = self.edit_mode and self.edit_style is not None
847+
self.edit_mode = self.edit_mode and self.can_edit
849848

850849
def _handle_style_changed(self):
851850
self.style_changed.emit(self.style)
852-
self.edit_mode = self.edit_mode and self.edit_style is not None
851+
self.edit_mode = self.edit_mode and self.can_edit
853852

854853
def generate_seed(self):
855854
self.seed = workflow.generate_seed()
@@ -879,6 +878,8 @@ def add_refs(control: list[ControlInput], layer_names: list[str]):
879878
for region, r_layers in zip(cond.regions, region_layers):
880879
add_refs(region.control, r_layers)
881880

881+
cond.edit_reference = self.is_editing
882+
882883
def _performance_settings(self, client: Client):
883884
result = client.performance_settings
884885
if self.resolution_multiplier != 1.0:
@@ -950,14 +951,22 @@ def name(self):
950951
@property
951952
def edit_style(self) -> Style | None:
952953
style_arch = resolve_arch(self.style, self._connection.client_if_connected)
953-
if style_arch.is_edit:
954+
if style_arch.supports_edit:
954955
return self.style
955956
if style_id := self.style.linked_edit_style:
956957
if style := Styles.list().find(style_id):
957958
if is_style_supported(style, self._connection.client_if_connected):
958959
return style
959960
return None
960961

962+
@property
963+
def can_edit(self):
964+
return self.edit_style is not None
965+
966+
@property
967+
def is_editing(self):
968+
return self.arch.is_edit or (self.can_edit and self.edit_mode)
969+
961970

962971
class CustomInpaint(QObject, ObservableProperties):
963972
mode = Property(InpaintMode.automatic, persist=True)
@@ -1323,7 +1332,7 @@ def _prepare_input(self, canvas: Image | Extent, seed: int, time: int):
13231332
m = self._model
13241333

13251334
kind = WorkflowKind.generate
1326-
if m.strength < 1.0 or m.arch.is_edit:
1335+
if m.strength < 1.0 or m.is_editing:
13271336
kind = WorkflowKind.refine
13281337
bounds = Bounds(0, 0, *m.document.extent)
13291338
conditioning, _ = process_regions(m.regions, bounds, self._model.layers.root, time=time)

ai_diffusion/presets/samplers.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
"minimum_steps": 3,
9797
"cfg": 1.0
9898
},
99-
"Flux2 - Euler": {
99+
"Flux 2 - Euler": {
100100
"sampler": "euler",
101101
"scheduler": "flux2",
102102
"steps": 20,

ai_diffusion/resources.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class Arch(Enum):
8080
sd3 = "SD 3"
8181
flux = "Flux"
8282
flux_k = "Flux Kontext"
83-
flux2 = "Flux 2"
83+
flux2 = "Flux 2 Klein 4B"
8484
illu = "Illustrious"
8585
illu_v = "Illustrious v-prediction"
8686
chroma = "Chroma"
@@ -108,7 +108,7 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
108108
return Arch.flux_k
109109
if string == "flux" or string == "flux-schnell":
110110
return Arch.flux
111-
if string == "flux2":
111+
if string == "flux2" and model_type == "klein-4b":
112112
return Arch.flux2
113113
if string == "illu":
114114
return Arch.illu
@@ -188,6 +188,10 @@ def supports_cfg(self):
188188
def is_edit(self): # edit models make changes to input images
189189
return self in [Arch.flux_k, Arch.qwen_e, Arch.qwen_e_p, Arch.qwen_l]
190190

191+
@property
192+
def supports_edit(self): # includes text-to-image models that can also edit
193+
return self.is_edit or self is Arch.flux2
194+
191195
@property
192196
def is_sdxl_like(self):
193197
# illustrious technically uses sdxl architecture, but has a separate ecosystem
@@ -749,7 +753,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
749753
resource_id(ResourceKind.text_encoder, Arch.all, "clip_g"): ["clip_g"],
750754
resource_id(ResourceKind.text_encoder, Arch.all, "t5"): ["t5xxl_fp16", "t5xxl_fp8_e4m3fn", "t5xxl_fp8_e4m3fn_scaled", "t5-v1_1-xxl", "t5"],
751755
resource_id(ResourceKind.text_encoder, Arch.all, "qwen"): ["qwen_2.5_vl_7b", "qwen_2", "qwen-2", "qwen"],
752-
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3"): ["qwen_3_4b", "qwen_3", "qwen-3"],
756+
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3"): ["qwen_3_4b", "qwen3-4b", "qwen_3", "qwen-3"],
753757
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
754758
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
755759
resource_id(ResourceKind.vae, Arch.illu, "default"): ["sdxl_vae"],

ai_diffusion/styles/flux2-klein.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
"v_prediction_zsnr": false,
1717
"self_attention_guidance": false,
1818
"preferred_resolution": 0,
19-
"sampler": "Flux2 - Euler",
19+
"sampler": "Flux 2 - Euler",
2020
"sampler_steps": 4,
2121
"cfg_scale": 1.0,
22-
"live_sampler": "Flux2 - Euler",
22+
"live_sampler": "Flux 2 - Euler",
2323
"live_sampler_steps": 4,
2424
"live_cfg_scale": 1.0
2525
}

ai_diffusion/ui/generation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -900,18 +900,18 @@ def show_inpaint_menu(self):
900900
menu = self.generate_region_menu
901901
elif self.model.document.selection_bounds:
902902
menu = self.inpaint_menu
903-
menu.actions()[-2].setEnabled(self.model.edit_style is not None)
903+
menu.actions()[-2].setEnabled(self.model.can_edit)
904904
else:
905905
menu = self.generate_menu
906906
else:
907907
if self.model.region_only:
908908
menu = self.refine_region_menu
909909
elif self.model.document.selection_bounds:
910910
menu = self.refine_selection_menu
911-
menu.actions()[1].setEnabled(self.model.edit_style is not None)
911+
menu.actions()[1].setEnabled(self.model.can_edit)
912912
else:
913913
menu = self.refine_menu
914-
menu.actions()[1].setEnabled(self.model.edit_style is not None)
914+
menu.actions()[1].setEnabled(self.model.can_edit)
915915

916916
menu.setFixedWidth(width)
917917
menu.exec_(self.generate_button.mapToGlobal(pos))
@@ -943,10 +943,8 @@ def update_generate_options(self):
943943
has_regions = len(regions) > 0
944944
has_active_region = regions.is_linked(self.model.layers.active)
945945
is_region_only = has_regions and has_active_region and self.model.region_only
946-
is_edit = arch.is_edit
947-
can_switch_edit = (
948-
self.model.style.linked_edit_style != "" and self.model.edit_style is not None
949-
)
946+
is_edit = self.model.is_editing
947+
can_switch_edit = self.model.can_edit and not arch.is_edit
950948
self.region_mask_button.setVisible(has_regions)
951949
self.region_mask_button.setEnabled(has_active_region)
952950
self.region_mask_button.setIcon(_region_mask_button_icons[is_region_only])

ai_diffusion/workflow.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ class Conditioning:
397397
control: list[Control] = field(default_factory=list)
398398
regions: list[Region] = field(default_factory=list)
399399
style_prompt: str = ""
400+
edit_reference: bool = False
400401

401402
@staticmethod
402403
def from_input(i: ConditioningInput):
@@ -406,6 +407,7 @@ def from_input(i: ConditioningInput):
406407
[Control.from_input(c) for c in i.control],
407408
[Region.from_input(r, idx, i.language) for idx, r in enumerate(i.regions)],
408409
i.style,
410+
i.edit_reference,
409411
)
410412

411413
def copy(self):
@@ -415,6 +417,7 @@ def copy(self):
415417
[copy(c) for c in self.control],
416418
[r.copy() for r in self.regions],
417419
self.style_prompt,
420+
self.edit_reference,
418421
)
419422

420423
def downscale(self, original: Extent, target: Extent):
@@ -614,8 +617,8 @@ def apply_ip_adapter(
614617
models: ModelDict,
615618
mask: Output | None = None,
616619
):
617-
if models.arch.is_flux_like or models.arch.is_qwen_like:
618-
return model # No IP-adapter for Flux or Qwen, using Style model instead
620+
if not (models.arch is Arch.sd15 or models.arch.is_sdxl_like):
621+
return model
619622

620623
models = models.ip_adapter
621624

@@ -682,35 +685,39 @@ def apply_regional_ip_adapter(
682685
return model
683686

684687

685-
def apply_edit_conditioning(
688+
def apply_reference_conditioning(
686689
w: ComfyWorkflow,
687-
cond: Output,
688-
input_image: Output,
689-
input_latent: Output,
690-
control_layers: list[Control],
690+
positive: Output,
691+
input_image: Output | None,
692+
input_latent: Output | None,
693+
cond: Conditioning,
691694
vae: Output,
692695
arch: Arch,
693696
tiled_vae: bool,
694697
):
695-
if not arch.is_edit:
696-
return cond
697-
698-
extra_input = [c.image for c in control_layers if c.mode.is_ip_adapter]
699-
if len(extra_input) == 0:
700-
return w.reference_latent(cond, input_latent)
701-
702-
if arch == Arch.qwen_e_p:
703-
extra_images = [i.load(w) for i in extra_input]
704-
cond = w.reference_latent(cond, input_latent)
705-
for extra_image in extra_images:
706-
latent = vae_encode(w, vae, extra_image, tiled_vae)
707-
cond = w.reference_latent(cond, latent)
708-
return cond
709-
else:
710-
input = w.image_stitch([input_image] + [i.load(w) for i in extra_input])
711-
latent = vae_encode(w, vae, input, tiled_vae)
712-
cond = w.reference_latent(cond, latent)
713-
return cond
698+
if not arch.supports_edit:
699+
return positive
700+
701+
extra_input = (c.image for c in cond.all_control if c.mode.is_ip_adapter)
702+
extra_images = [i.load(w) for i in extra_input]
703+
match arch:
704+
case Arch.flux2 | Arch.qwen_e_p:
705+
if cond.edit_reference and input_latent:
706+
positive = w.reference_latent(positive, input_latent)
707+
for extra_image in extra_images:
708+
latent = vae_encode(w, vae, extra_image, tiled_vae)
709+
positive = w.reference_latent(positive, latent)
710+
case Arch.flux_k | Arch.qwen_e:
711+
if len(extra_images) > 0:
712+
if cond.edit_reference and input_image:
713+
extra_images.insert(0, input_image)
714+
input = w.image_stitch(extra_images)
715+
latent = vae_encode(w, vae, input, tiled_vae)
716+
positive = w.reference_latent(positive, latent)
717+
elif cond.edit_reference and input_latent:
718+
positive = w.reference_latent(positive, input_latent)
719+
720+
return positive
714721

715722

716723
def scale(
@@ -796,7 +803,9 @@ def scale_refine_and_decode(
796803
model, positive, negative = apply_control(
797804
w, model, positive, negative, cond.all_control, extent.desired, vae, models
798805
)
799-
positive = apply_edit_conditioning(w, positive, upscale, latent, [], vae, arch, tiled_vae)
806+
positive = apply_reference_conditioning(
807+
w, positive, upscale, latent, cond, vae, arch, tiled_vae
808+
)
800809
result = w.sampler_custom_advanced(model, positive, negative, latent, arch, **params)
801810
image = vae_decode(w, vae, result, tiled_vae)
802811
return image
@@ -834,6 +843,9 @@ def generate(
834843
model, positive, negative = apply_control(
835844
w, model, positive, negative, cond.all_control, extent.initial, vae, models
836845
)
846+
positive = apply_reference_conditioning(
847+
w, positive, None, None, cond, vae, models.arch, checkpoint.tiled_vae
848+
)
837849
sample_params = _sampler_params(sampling, extent.initial)
838850
out_latent = w.sampler_custom_advanced(
839851
model, positive, negative, latent, models.arch, **sample_params
@@ -1092,8 +1104,8 @@ def refine(
10921104
model, positive, negative = apply_control(
10931105
w, model, positive, negative, cond.all_control, extent.desired, vae, models
10941106
)
1095-
positive = apply_edit_conditioning(
1096-
w, positive, in_image, latent, cond.all_control, vae, models.arch, checkpoint.tiled_vae
1107+
positive = apply_reference_conditioning(
1108+
w, positive, in_image, latent, cond, vae, models.arch, checkpoint.tiled_vae
10971109
)
10981110
sampler_params = _sampler_params(sampling, extent.desired)
10991111
sampler = w.sampler_custom_advanced(
@@ -1147,8 +1159,8 @@ def refine_region(
11471159
inpaint_model = w.apply_fooocus_inpaint(model, inpaint_patch, latent_inpaint)
11481160
else:
11491161
latent = vae_encode(w, vae, in_image, checkpoint.tiled_vae)
1150-
positive = apply_edit_conditioning(
1151-
w, positive, in_image, latent, cond.all_control, vae, models.arch, checkpoint.tiled_vae
1162+
positive = apply_reference_conditioning(
1163+
w, positive, in_image, latent, cond, vae, models.arch, checkpoint.tiled_vae
11521164
)
11531165
latent = w.set_latent_noise_mask(latent, initial_mask)
11541166
inpaint_model = model
@@ -1321,8 +1333,8 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
13211333

13221334
latent = vae_encode(w, vae, tile_image, checkpoint.tiled_vae)
13231335
latent = w.set_latent_noise_mask(latent, tile_mask)
1324-
positive = apply_edit_conditioning(
1325-
w, positive, tile_image, latent, control, vae, models.arch, checkpoint.tiled_vae
1336+
positive = apply_reference_conditioning(
1337+
w, positive, tile_image, latent, tile_cond, vae, models.arch, checkpoint.tiled_vae
13261338
)
13271339
sampler_params = _sampler_params(sampling, layout.bounds(i).extent)
13281340
sampler = w.sampler_custom_advanced(
@@ -1443,7 +1455,7 @@ def prepare_prompts(
14431455
"negative_prompt": cond.negative,
14441456
}
14451457
models = style.get_models([])
1446-
layer_replace = "Picture {}" if arch is Arch.qwen_e_p else ""
1458+
layer_replace = "Picture {}" if arch in (Arch.qwen_e_p, Arch.flux2) else ""
14471459

14481460
cond.style = style.style_prompt
14491461
cond.positive = strip_prompt_comments(cond.positive)

tests/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
Arch.sdxl: "RealVisXL_V5.0_fp16.safetensors",
2121
Arch.flux: "svdq-int4_r32-flux.1-krea-dev.safetensors",
2222
Arch.flux_k: "svdq-int4_r32-flux.1-kontext-dev.safetensors",
23+
Arch.flux2: "flux-2-klein-4b.safetensors",
2324
Arch.zimage: "z_image_turbo_bf16.safetensors",
2425
}

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def order(item: pytest.Item):
3434
return 11
3535
elif "cloud" in item.name:
3636
return 10
37+
elif "flux2" in item.name:
38+
return 4
3739
elif "flux" in item.name:
3840
return 3
3941
elif "sdxl" in item.name:

0 commit comments

Comments
 (0)