@@ -397,6 +397,7 @@ class Conditioning:
397397 control : list [Control ] = field (default_factory = list )
398398 regions : list [Region ] = field (default_factory = list )
399399 style_prompt : str = ""
400+ edit_reference : bool = False
400401
401402 @staticmethod
402403 def from_input (i : ConditioningInput ):
@@ -406,6 +407,7 @@ def from_input(i: ConditioningInput):
406407 [Control .from_input (c ) for c in i .control ],
407408 [Region .from_input (r , idx , i .language ) for idx , r in enumerate (i .regions )],
408409 i .style ,
410+ i .edit_reference ,
409411 )
410412
411413 def copy (self ):
@@ -415,6 +417,7 @@ def copy(self):
415417 [copy (c ) for c in self .control ],
416418 [r .copy () for r in self .regions ],
417419 self .style_prompt ,
420+ self .edit_reference ,
418421 )
419422
420423 def downscale (self , original : Extent , target : Extent ):
@@ -614,8 +617,8 @@ def apply_ip_adapter(
614617 models : ModelDict ,
615618 mask : Output | None = None ,
616619):
617- if models .arch . is_flux_like or models .arch .is_qwen_like :
618- return model # No IP-adapter for Flux or Qwen, using Style model instead
620+ if not ( models .arch is Arch . sd15 or models .arch .is_sdxl_like ) :
621+ return model
619622
620623 models = models .ip_adapter
621624
@@ -682,35 +685,39 @@ def apply_regional_ip_adapter(
682685 return model
683686
684687
685- def apply_edit_conditioning (
688+ def apply_reference_conditioning (
686689 w : ComfyWorkflow ,
687- cond : Output ,
688- input_image : Output ,
689- input_latent : Output ,
690- control_layers : list [ Control ] ,
690+ positive : Output ,
691+ input_image : Output | None ,
692+ input_latent : Output | None ,
693+ cond : Conditioning ,
691694 vae : Output ,
692695 arch : Arch ,
693696 tiled_vae : bool ,
694697):
695- if not arch .is_edit :
696- return cond
697-
698- extra_input = [c .image for c in control_layers if c .mode .is_ip_adapter ]
699- if len (extra_input ) == 0 :
700- return w .reference_latent (cond , input_latent )
701-
702- if arch == Arch .qwen_e_p :
703- extra_images = [i .load (w ) for i in extra_input ]
704- cond = w .reference_latent (cond , input_latent )
705- for extra_image in extra_images :
706- latent = vae_encode (w , vae , extra_image , tiled_vae )
707- cond = w .reference_latent (cond , latent )
708- return cond
709- else :
710- input = w .image_stitch ([input_image ] + [i .load (w ) for i in extra_input ])
711- latent = vae_encode (w , vae , input , tiled_vae )
712- cond = w .reference_latent (cond , latent )
713- return cond
698+ if not arch .supports_edit :
699+ return positive
700+
701+ extra_input = (c .image for c in cond .all_control if c .mode .is_ip_adapter )
702+ extra_images = [i .load (w ) for i in extra_input ]
703+ match arch :
704+ case Arch .flux2 | Arch .qwen_e_p :
705+ if cond .edit_reference and input_latent :
706+ positive = w .reference_latent (positive , input_latent )
707+ for extra_image in extra_images :
708+ latent = vae_encode (w , vae , extra_image , tiled_vae )
709+ positive = w .reference_latent (positive , latent )
710+ case Arch .flux_k | Arch .qwen_e :
711+ if len (extra_images ) > 0 :
712+ if cond .edit_reference and input_image :
713+ extra_images .insert (0 , input_image )
714+ input = w .image_stitch (extra_images )
715+ latent = vae_encode (w , vae , input , tiled_vae )
716+ positive = w .reference_latent (positive , latent )
717+ elif cond .edit_reference and input_latent :
718+ positive = w .reference_latent (positive , input_latent )
719+
720+ return positive
714721
715722
716723def scale (
@@ -796,7 +803,9 @@ def scale_refine_and_decode(
796803 model , positive , negative = apply_control (
797804 w , model , positive , negative , cond .all_control , extent .desired , vae , models
798805 )
799- positive = apply_edit_conditioning (w , positive , upscale , latent , [], vae , arch , tiled_vae )
806+ positive = apply_reference_conditioning (
807+ w , positive , upscale , latent , cond , vae , arch , tiled_vae
808+ )
800809 result = w .sampler_custom_advanced (model , positive , negative , latent , arch , ** params )
801810 image = vae_decode (w , vae , result , tiled_vae )
802811 return image
@@ -834,6 +843,9 @@ def generate(
834843 model , positive , negative = apply_control (
835844 w , model , positive , negative , cond .all_control , extent .initial , vae , models
836845 )
846+ positive = apply_reference_conditioning (
847+ w , positive , None , None , cond , vae , models .arch , checkpoint .tiled_vae
848+ )
837849 sample_params = _sampler_params (sampling , extent .initial )
838850 out_latent = w .sampler_custom_advanced (
839851 model , positive , negative , latent , models .arch , ** sample_params
@@ -1092,8 +1104,8 @@ def refine(
10921104 model , positive , negative = apply_control (
10931105 w , model , positive , negative , cond .all_control , extent .desired , vae , models
10941106 )
1095- positive = apply_edit_conditioning (
1096- w , positive , in_image , latent , cond . all_control , vae , models .arch , checkpoint .tiled_vae
1107+ positive = apply_reference_conditioning (
1108+ w , positive , in_image , latent , cond , vae , models .arch , checkpoint .tiled_vae
10971109 )
10981110 sampler_params = _sampler_params (sampling , extent .desired )
10991111 sampler = w .sampler_custom_advanced (
@@ -1147,8 +1159,8 @@ def refine_region(
11471159 inpaint_model = w .apply_fooocus_inpaint (model , inpaint_patch , latent_inpaint )
11481160 else :
11491161 latent = vae_encode (w , vae , in_image , checkpoint .tiled_vae )
1150- positive = apply_edit_conditioning (
1151- w , positive , in_image , latent , cond . all_control , vae , models .arch , checkpoint .tiled_vae
1162+ positive = apply_reference_conditioning (
1163+ w , positive , in_image , latent , cond , vae , models .arch , checkpoint .tiled_vae
11521164 )
11531165 latent = w .set_latent_noise_mask (latent , initial_mask )
11541166 inpaint_model = model
@@ -1321,8 +1333,8 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
13211333
13221334 latent = vae_encode (w , vae , tile_image , checkpoint .tiled_vae )
13231335 latent = w .set_latent_noise_mask (latent , tile_mask )
1324- positive = apply_edit_conditioning (
1325- w , positive , tile_image , latent , control , vae , models .arch , checkpoint .tiled_vae
1336+ positive = apply_reference_conditioning (
1337+ w , positive , tile_image , latent , tile_cond , vae , models .arch , checkpoint .tiled_vae
13261338 )
13271339 sampler_params = _sampler_params (sampling , layout .bounds (i ).extent )
13281340 sampler = w .sampler_custom_advanced (
@@ -1443,7 +1455,7 @@ def prepare_prompts(
14431455 "negative_prompt" : cond .negative ,
14441456 }
14451457 models = style .get_models ([])
1446- layer_replace = "Picture {}" if arch is Arch .qwen_e_p else ""
1458+ layer_replace = "Picture {}" if arch in ( Arch .qwen_e_p , Arch . flux2 ) else ""
14471459
14481460 cond .style = style .style_prompt
14491461 cond .positive = strip_prompt_comments (cond .positive )
0 commit comments