19 | 19 | import torch |
20 | 20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast |
21 | 21 |
| 22 | +from ...configuration_utils import FrozenDict |
| 23 | +from ...image_processor import VaeImageProcessor |
22 | 24 | from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin |
| 25 | +from ...models import AutoencoderKL |
23 | 26 | from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers |
24 | 27 | from ..modular_pipeline import PipelineBlock, PipelineState |
25 | 28 | from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam |
@@ -50,6 +53,113 @@ def prompt_clean(text): |
50 | 53 |     return text
51 | 54 |
52 | 55 |
| 56 | +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
| 57 | +def retrieve_latents(
| 58 | +    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
| 59 | +):
| 60 | +    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
| 61 | +        return encoder_output.latent_dist.sample(generator)
| 62 | +    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
| 63 | +        return encoder_output.latent_dist.mode()
| 64 | +    elif hasattr(encoder_output, "latents"):
| 65 | +        return encoder_output.latents
| 66 | +    else:
| 67 | +        raise AttributeError("Could not access latents of provided encoder_output")
| 68 | +
| 69 | +
| 70 | +class FluxVaeEncoderStep(PipelineBlock):
| 71 | +    model_name = "flux"
| 72 | +
| 73 | +    @property
| 74 | +    def description(self) -> str:
| 75 | +        return "VAE encoder step that encodes the input image into a latent representation"
| 76 | +
| 77 | +    @property
| 78 | +    def expected_components(self) -> List[ComponentSpec]:
| 79 | +        return [
| 80 | +            ComponentSpec("vae", AutoencoderKL),
| 81 | +            ComponentSpec(
| 82 | +                "image_processor",
| 83 | +                VaeImageProcessor,
| 84 | +                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
| 85 | +                default_creation_method="from_config",
| 86 | +            ),
| 87 | +        ]
| 88 | +
| 89 | +    @property
| 90 | +    def inputs(self) -> List[InputParam]:
| 91 | +        return [
| 92 | +            InputParam("image", required=True),
| 93 | +            InputParam("height"),
| 94 | +            InputParam("width"),
| 95 | +        ]
| 96 | +
| 97 | +    @property
| 98 | +    def intermediate_inputs(self) -> List[InputParam]:
| 99 | +        return [
| 100 | +            InputParam("generator"),
| 101 | +            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
| 102 | +            InputParam(
| 103 | +                "preprocess_kwargs",
| 104 | +                type_hint=Optional[dict],
| 105 | +                description="A kwargs dictionary that, if specified, is passed along to the `VaeImageProcessor` configured as this block's `image_processor` component ([diffusers.image_processor.VaeImageProcessor])",
| 106 | +            ),
| 107 | +        ]
| 108 | +
| 109 | +    @property
| 110 | +    def intermediate_outputs(self) -> List[OutputParam]:
| 111 | +        return [
| 112 | +            OutputParam(
| 113 | +                "image_latents",
| 114 | +                type_hint=torch.Tensor,
| 115 | +                description="The latents representing the reference image for image-to-image/inpainting generation",
| 116 | +            )
| 117 | +        ]
| 118 | +
| 119 | +    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
| 120 | +    def _encode_vae_image(self, vae, image: torch.Tensor, generator: torch.Generator):
| 121 | +        if isinstance(generator, list):
| 122 | +            image_latents = [
| 123 | +                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
| 124 | +            ]
| 125 | +            image_latents = torch.cat(image_latents, dim=0)
| 126 | +        else:
| 127 | +            image_latents = retrieve_latents(vae.encode(image), generator=generator)
| 128 | +
| 129 | +        image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
| 130 | +
| 131 | +        return image_latents
| 132 | +
| 133 | +    @torch.no_grad()
| 134 | +    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
| 135 | +        block_state = self.get_block_state(state)
| 136 | +        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
| 137 | +        block_state.device = components._execution_device
| 138 | +        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
| 139 | +
| 140 | +        block_state.image = components.image_processor.preprocess(
| 141 | +            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
| 142 | +        )
| 143 | +        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
| 144 | +
| 145 | +        block_state.batch_size = block_state.image.shape[0]
| 146 | +
| 147 | +        # if generator is a list, make sure its length matches the number of images (both should equal batch_size)
| 148 | +        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
| 149 | +            raise ValueError(
| 150 | +                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
| 151 | +                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
| 152 | +            )
| 153 | +
| 154 | +        block_state.image_latents = self._encode_vae_image(
| 155 | +            components.vae, image=block_state.image, generator=block_state.generator
| 156 | +        )
| 157 | +
| 158 | +        self.set_block_state(state, block_state)
| 159 | +
| 160 | +        return components, state
| 161 | +
| 162 | +
53 | 163 | class FluxTextEncoderStep(PipelineBlock): |
54 | 164 |     model_name = "flux"
55 | 165 |
@@ -297,7 +407,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip |
297 | 407 |             prompt_embeds=None,
298 | 408 |             pooled_prompt_embeds=None,
299 | 409 |             device=block_state.device,
300 | | -            num_images_per_prompt=1,  # hardcoded for now.
| 410 | +            num_images_per_prompt=1,  # TODO: hardcoded for now.
301 | 411 |             lora_scale=block_state.text_encoder_lora_scale,
302 | 412 |         )
303 | 413 |
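
For context while reviewing, here is a minimal sketch (not part of the diff) of what the new VAE-encoding path computes for a single image. The checkpoint id and the random input are illustrative assumptions; any `AutoencoderKL` whose config defines `shift_factor` and `scaling_factor` behaves the same way.

```python
import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; the Flux VAE lives in the "vae" subfolder of the repo.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")

# Stand-in for a preprocessed image in [-1, 1], as VaeImageProcessor.preprocess returns.
image = torch.randn(1, 3, 512, 512)

with torch.no_grad():
    enc = vae.encode(image)  # output exposes `latent_dist`

# retrieve_latents with sample_mode="sample" draws from the posterior;
# "argmax" would take enc.latent_dist.mode() instead.
generator = torch.Generator().manual_seed(0)
latents = enc.latent_dist.sample(generator)

# The same normalization _encode_vae_image applies before the latents reach denoising:
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
```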
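A companion sketch for the batched path: passing one generator per image routes `_encode_vae_image` through its per-sample branch, so each latent draw is independently seeded and reproducible. This reuses the `vae` from the snippet above and assumes the block can be instantiated standalone.

```python
images = torch.randn(2, 3, 512, 512)  # illustrative batch of preprocessed images
generators = [torch.Generator().manual_seed(i) for i in range(images.shape[0])]

block = FluxVaeEncoderStep()
with torch.no_grad():
    image_latents = block._encode_vae_image(vae, image=images, generator=generators)

# Flux's VAE downsamples by 8 with 16 latent channels, so 512x512 inputs
# yield latents of shape (2, 16, 64, 64).
print(image_latents.shape)
```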