Commit a79f994

fix community pipeline for semantic guidance for flux
1 parent 481c88a commit a79f994

1 file changed (+15, -5)

examples/community/pipeline_flux_semantic_guidance.py

Lines changed: 15 additions & 5 deletions
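
The additions are mechanical: they insert the `# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.<method>` markers that diffusers' `make fix-copies` tooling checks, so the methods this community pipeline duplicates from `FluxPipeline` stay in sync with upstream. The deletions drop commented-out dead code from the semantic-guidance section of `__call__`.

For orientation, a minimal usage sketch of the pipeline being fixed. The checkpoint id and the editing kwargs are assumptions inferred from identifiers visible in this diff (e.g. `edit_guidance_scale`) and from the SEGA-style pipelines in diffusers; they are not confirmed by this commit:

import torch
from diffusers import DiffusionPipeline

# Community pipelines load by file name via `custom_pipeline`.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base checkpoint
    custom_pipeline="pipeline_flux_semantic_guidance",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a photo of a cat",
    editing_prompt=["sunglasses"],  # hypothetical kwarg: concept to steer toward
    edit_guidance_scale=[5.0],      # name appears in this diff; float or per-concept list
    num_inference_steps=28,
).images[0]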
@@ -230,6 +230,7 @@ def __init__(
         )
         self.default_sample_size = 128
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
@@ -279,6 +280,7 @@ def _get_t5_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
     def _get_clip_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
@@ -323,6 +325,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -513,6 +516,7 @@ def encode_text_with_editing(
             enabled_editing_prompts,
         )
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
 
@@ -524,6 +528,7 @@ def encode_image(self, image, device, num_images_per_prompt):
         image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         return image_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
     ):
@@ -555,6 +560,7 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -633,6 +639,7 @@ def check_inputs(
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
     @staticmethod
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         latent_image_ids = torch.zeros(height, width, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
@@ -647,6 +654,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         return latent_image_ids.to(device=device, dtype=dtype)
 
     @staticmethod
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -655,6 +663,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         return latents
 
     @staticmethod
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
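The `_pack_latents` / `_unpack_latents` pair annotated above is Flux's usual 2x2 latent patchification. A standalone sketch of the round trip, mirroring the context lines shown in the diff (the latent shape is an illustrative assumption):

import torch

b, c, h, w = 1, 16, 64, 64  # assumed pre-packing latent shape
latents = torch.randn(b, c, h, w)

# _pack_latents: (b, c, h, w) -> (b, h/2 * w/2, c * 4)
packed = latents.view(b, c, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(b, (h // 2) * (w // 2), c * 4)

# the inverse, as _unpack_latents undoes it
unpacked = packed.view(b, h // 2, w // 2, c, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(b, c, h, w)

assert torch.equal(latents, unpacked)  # exact round trip: views and reshapes only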
@@ -670,20 +679,23 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
         self.vae.enable_slicing()
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
     def disable_vae_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
         computing decoding in one step.
         """
         self.vae.disable_slicing()
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
     def enable_vae_tiling(self):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -692,13 +704,15 @@ def enable_vae_tiling(self):
         """
         self.vae.enable_tiling()
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
     def disable_vae_tiling(self):
         r"""
         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
         computing decoding in one step.
         """
         self.vae.disable_tiling()
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
     def prepare_latents(
         self,
         batch_size,
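
The four VAE helpers above just proxy to `self.vae`, so on a constructed pipeline they toggle memory-saving decode modes; continuing the hypothetical `pipe` from the usage sketch above:

pipe.enable_vae_slicing()   # decode the batch one slice at a time
pipe.enable_vae_tiling()    # decode large images tile by tile
image = pipe(prompt="a photo of a cat", height=1024, width=1024).images[0]
pipe.disable_vae_slicing()  # restore single-step decoding
pipe.disable_vae_tiling()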
@@ -1171,7 +1185,7 @@ def __call__(
                         device=device,
                         dtype=noise_guidance.dtype,
                     )
-                    # noise_guidance_edit = torch.zeros_like(noise_guidance)
+
                     warmup_inds = []
                     for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
                         if isinstance(edit_guidance_scale, list):
@@ -1244,9 +1258,6 @@ def __call__(
                             )
 
                             noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp
-                            # noise_guidance_edit[c] = noise_guidance_edit_tmp
-
-                            # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
 
                     warmup_inds = torch.tensor(warmup_inds).to(device)
                     if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
@@ -1258,7 +1269,6 @@ def __call__(
                             concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
                         )
                         concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
-                        # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
 
                         noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                         noise_guidance_edit_tmp = torch.einsum(
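
The last three hunks only delete commented-out alternatives; the live logic around them clamps negative concept weights to zero during warmup and renormalizes them before combining the per-concept guidance. A self-contained sketch of that weighting under assumed tensor shapes (the real pipeline's shapes and einsum subscripts may differ):

import torch

concept_weights = torch.tensor([[0.8], [-0.3], [0.5]])  # (num_concepts, batch)
noise_guidance_edit = torch.randn(3, 1, 16, 64)         # per-concept guidance terms
warmup_inds = torch.tensor([0, 2])                      # concepts still warming up

w = torch.index_select(concept_weights, 0, warmup_inds)
w = torch.where(w < 0, torch.zeros_like(w), w)  # zero out negative weights
w = w / w.sum(dim=0)                            # renormalize over concepts
# if every selected weight were negative, this division yields NaN; the
# removed `torch.nan_to_num` comment hints at that edge case

g = torch.index_select(noise_guidance_edit, 0, warmup_inds)
combined = torch.einsum("cb,cb...->b...", w, g)  # weighted sum over warmup concepts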
