Skip to content

Commit f00ed73

Browse files
committed
Support Flux 2 Klein 9B variant
1 parent 99f04fb commit f00ed73

File tree

8 files changed

+37
-21
lines changed

8 files changed

+37
-21
lines changed

ai_diffusion/comfy_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ def _find_text_encoder_models(model_list: Sequence[str]):
727727
kind = ResourceKind.text_encoder
728728
return {
729729
resource_id(kind, Arch.all, te): _find_model(model_list, kind, Arch.all, te)
730-
for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3"]
730+
for te in ["clip_l", "clip_g", "t5", "qwen", "qwen_3_4b", "qwen_3_8b"]
731731
}
732732

733733

ai_diffusion/comfy_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def empty_latent_image(self, extent: Extent, arch: Arch, batch_size=1):
601601
w, h = extent.width, extent.height
602602
if arch.is_flux_like or arch.is_qwen_like or arch in (Arch.sd3, Arch.chroma, Arch.zimage):
603603
return self.add("EmptySD3LatentImage", 1, width=w, height=h, batch_size=batch_size)
604-
if arch is Arch.flux2:
604+
if arch.is_flux2:
605605
return self.add("EmptyFlux2LatentImage", 1, width=w, height=h, batch_size=batch_size)
606606
else:
607607
return self.add("EmptyLatentImage", 1, width=w, height=h, batch_size=batch_size)

ai_diffusion/resolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def prepare_diffusion_input(
183183

184184
# The checkpoint may require a different resolution than what is requested.
185185
mult = 8
186-
if arch.is_flux_like or arch in (Arch.chroma, Arch.flux2):
186+
if arch.is_flux_like or arch is Arch.chroma or arch.is_flux2:
187187
mult = 16
188188
if arch is Arch.sd3:
189189
mult = 64

ai_diffusion/resources.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ class Arch(Enum):
8080
sd3 = "SD 3"
8181
flux = "Flux"
8282
flux_k = "Flux Kontext"
83-
flux2 = "Flux 2 Klein 4B"
83+
flux2_4b = "Flux 2 Klein 4B"
84+
flux2_9b = "Flux 2 Klein 9B"
8485
illu = "Illustrious"
8586
illu_v = "Illustrious v-prediction"
8687
chroma = "Chroma"
@@ -109,7 +110,9 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
109110
if string == "flux" or string == "flux-schnell":
110111
return Arch.flux
111112
if string == "flux2" and model_type == "klein-4b":
112-
return Arch.flux2
113+
return Arch.flux2_4b
114+
if string == "flux2" and model_type == "klein-9b":
115+
return Arch.flux2_9b
113116
if string == "illu":
114117
return Arch.illu
115118
if string == "illu_v":
@@ -190,7 +193,7 @@ def is_edit(self): # edit models make changes to input images
190193

191194
@property
192195
def supports_edit(self): # includes text-to-image models that can also edit
193-
return self.is_edit or self is Arch.flux2
196+
return self.is_edit or self.is_flux2
194197

195198
@property
196199
def is_sdxl_like(self):
@@ -201,6 +204,10 @@ def is_sdxl_like(self):
201204
def is_flux_like(self):
202205
return self in [Arch.flux, Arch.flux_k]
203206

207+
@property
208+
def is_flux2(self):
209+
return self in [Arch.flux2_4b, Arch.flux2_9b]
210+
204211
@property
205212
def is_qwen_like(self):
206213
return self in [Arch.qwen, Arch.qwen_e, Arch.qwen_e_p, Arch.qwen_l]
@@ -216,12 +223,16 @@ def text_encoders(self):
216223
return ["clip_l", "clip_g"]
217224
case Arch.flux | Arch.flux_k:
218225
return ["clip_l", "t5"]
226+
case Arch.flux2_4b:
227+
return ["qwen_3_4b"]
228+
case Arch.flux2_9b:
229+
return ["qwen_3_8b"]
219230
case Arch.chroma:
220231
return ["t5"]
221232
case Arch.qwen | Arch.qwen_e | Arch.qwen_e_p | Arch.qwen_l:
222233
return ["qwen"]
223-
case Arch.zimage | Arch.flux2:
224-
return ["qwen_3"]
234+
case Arch.zimage:
235+
return ["qwen_3_4b"]
225236
raise ValueError(f"Unsupported architecture: {self}")
226237

227238
@staticmethod
@@ -232,7 +243,8 @@ def list():
232243
Arch.sd3,
233244
Arch.flux,
234245
Arch.flux_k,
235-
Arch.flux2,
246+
Arch.flux2_4b,
247+
Arch.flux2_9b,
236248
Arch.illu,
237249
Arch.illu_v,
238250
Arch.chroma,
@@ -753,15 +765,17 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
753765
resource_id(ResourceKind.text_encoder, Arch.all, "clip_g"): ["clip_g"],
754766
resource_id(ResourceKind.text_encoder, Arch.all, "t5"): ["t5xxl_fp16", "t5xxl_fp8_e4m3fn", "t5xxl_fp8_e4m3fn_scaled", "t5-v1_1-xxl", "t5"],
755767
resource_id(ResourceKind.text_encoder, Arch.all, "qwen"): ["qwen_2.5_vl_7b", "qwen_2", "qwen-2", "qwen"],
756-
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3"): ["qwen_3_4b", "qwen3-4b", "qwen_3", "qwen-3"],
768+
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_4b"): ["qwen_3_4b", "qwen3-4b", "qwen_3", "qwen-3"],
769+
resource_id(ResourceKind.text_encoder, Arch.all, "qwen_3_8b"): ["qwen_3_8b", "qwen3-8b"],
757770
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
758771
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
759772
resource_id(ResourceKind.vae, Arch.illu, "default"): ["sdxl_vae"],
760773
resource_id(ResourceKind.vae, Arch.illu_v, "default"): ["sdxl_vae"],
761774
resource_id(ResourceKind.vae, Arch.sd3, "default"): ["sd3"],
762775
resource_id(ResourceKind.vae, Arch.flux, "default"): ["flux-", "flux_", "flux/", "flux1", "ae.s"],
763776
resource_id(ResourceKind.vae, Arch.flux_k, "default"): ["flux-", "flux_", "flux/", "flux1", "ae.s"],
764-
resource_id(ResourceKind.vae, Arch.flux2, "default"): ["flux2"],
777+
resource_id(ResourceKind.vae, Arch.flux2_4b, "default"): ["flux2"],
778+
resource_id(ResourceKind.vae, Arch.flux2_9b, "default"): ["flux2"],
765779
resource_id(ResourceKind.vae, Arch.chroma, "default"): ["flux-", "flux_", "flux/", "flux1", "ae.s"],
766780
resource_id(ResourceKind.vae, Arch.qwen, "default"): ["qwen"],
767781
resource_id(ResourceKind.vae, Arch.qwen_e, "default"): ["qwen"],

ai_diffusion/ui/theme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def checkpoint_icon(arch: Arch, format: FileFormat | None = None, client: Client
6363
return icon("sd-version-flux")
6464
elif arch is Arch.flux_k:
6565
return icon("sd-version-flux-k")
66-
elif arch is Arch.flux2:
66+
elif arch.is_flux2:
6767
return icon("sd-version-flux-2")
6868
elif arch is Arch.illu:
6969
return icon("sd-version-illu")

ai_diffusion/workflow.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,17 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
133133
clip = w.load_dual_clip(te["clip_g"], te["clip_l"], type="sd3")
134134
case Arch.flux | Arch.flux_k:
135135
clip = w.load_dual_clip(te["clip_l"], te["t5"], type="flux")
136-
case Arch.flux2:
137-
clip = w.load_clip(te["qwen_3"], type="flux2")
136+
case Arch.flux2_4b:
137+
clip = w.load_clip(te["qwen_3_4b"], type="flux2")
138+
case Arch.flux2_9b:
139+
clip = w.load_clip(te["qwen_3_8b"], type="flux2")
138140
case Arch.chroma:
139141
clip = w.load_clip(te["t5"], type="chroma")
140142
clip = w.t5_tokenizer_options(clip, min_padding=1, min_length=0)
141143
case Arch.qwen | Arch.qwen_e | Arch.qwen_e_p | Arch.qwen_l:
142144
clip = w.load_clip(te["qwen"], type="qwen_image")
143145
case Arch.zimage:
144-
clip = w.load_clip(te["qwen_3"], type="lumina2")
146+
clip = w.load_clip(te["qwen_3_4b"], type="lumina2")
145147
case _:
146148
raise RuntimeError(f"No text encoder for model architecture {arch.name}")
147149

@@ -701,7 +703,7 @@ def apply_reference_conditioning(
701703
extra_input = (c.image for c in cond.all_control if c.mode.is_ip_adapter)
702704
extra_images = [i.load(w) for i in extra_input]
703705
match arch:
704-
case Arch.flux2 | Arch.qwen_e_p:
706+
case Arch.flux2_4b | Arch.flux2_9b | Arch.qwen_e_p:
705707
if cond.edit_reference and input_latent:
706708
positive = w.reference_latent(positive, input_latent)
707709
for extra_image in extra_images:
@@ -1455,7 +1457,7 @@ def prepare_prompts(
14551457
"negative_prompt": cond.negative,
14561458
}
14571459
models = style.get_models([])
1458-
layer_replace = "Picture {}" if arch in (Arch.qwen_e_p, Arch.flux2) else ""
1460+
layer_replace = "Picture {}" if arch is Arch.qwen_e_p or arch.is_flux2 else ""
14591461

14601462
cond.style = style.style_prompt
14611463
cond.positive = strip_prompt_comments(cond.positive)

tests/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
Arch.sdxl: "RealVisXL_V5.0_fp16.safetensors",
2121
Arch.flux: "svdq-int4_r32-flux.1-krea-dev.safetensors",
2222
Arch.flux_k: "svdq-int4_r32-flux.1-kontext-dev.safetensors",
23-
Arch.flux2: "flux-2-klein-4b.safetensors",
23+
Arch.flux2_4b: "flux-2-klein-4b.safetensors",
2424
Arch.zimage: "z_image_turbo_bf16.safetensors",
2525
}

tests/test_workflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def default_style(client: Client, sd_ver=Arch.sd15):
8585
style.sampler = "Flux - Euler simple"
8686
style.cfg_scale = 1.0
8787
style.sampler_steps = 8
88-
if sd_ver is Arch.flux2:
88+
if sd_ver.is_flux2:
8989
style.sampler = "Flux 2 - Euler"
9090
style.cfg_scale = 1.0
9191
style.sampler_steps = 4
@@ -815,7 +815,7 @@ def test_refine_live(qtapp, client, sdver):
815815
run_and_save(qtapp, client, job, f"test_refine_live_{sdver.name}")
816816

817817

818-
@pytest.mark.parametrize("arch", [Arch.flux_k, Arch.flux2])
818+
@pytest.mark.parametrize("arch", [Arch.flux_k, Arch.flux2_4b])
819819
def test_edit(qtapp, local_client, arch):
820820
image = Image.load(image_dir / "flowers.webp")
821821
style = default_style(local_client, arch)
@@ -825,7 +825,7 @@ def test_edit(qtapp, local_client, arch):
825825
run_and_save(qtapp, local_client, job, f"test_edit_{arch.name}")
826826

827827

828-
@pytest.mark.parametrize("arch", [Arch.flux_k, Arch.flux2])
828+
@pytest.mark.parametrize("arch", [Arch.flux_k, Arch.flux2_4b])
829829
def test_edit_selection(qtapp, local_client, arch):
830830
image = Image.load(image_dir / "flowers.webp")
831831
mask = Mask.load(image_dir / "flowers_mask.png")

0 commit comments

Comments
 (0)