|
18 | 18 | import numpy as np |
19 | 19 | import torch |
20 | 20 |
|
| 21 | +from ...configuration_utils import FrozenDict |
| 22 | +from ...image_processor import VaeImageProcessor |
21 | 23 | from ...models import AutoencoderKL |
22 | 24 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
23 | 25 | from ...utils import logging |
@@ -182,15 +184,15 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): |
182 | 184 | return latent_image_ids.to(device=device, dtype=dtype) |
183 | 185 |
|
184 | 186 |
|
185 | | -# Cannot use "# Copied from" because it introduces weird indentation errors. |
186 | | -def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): |
| 187 | +def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator, sample_mode: str = "sample"): |
187 | 188 | if isinstance(generator, list): |
188 | 189 | image_latents = [ |
189 | | - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) |
| 190 | + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) |
| 191 | + for i in range(image.shape[0]) |
190 | 192 | ] |
191 | 193 | image_latents = torch.cat(image_latents, dim=0) |
192 | 194 | else: |
193 | | - image_latents = retrieve_latents(vae.encode(image), generator=generator) |
| 195 | + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) |
194 | 196 |
|
195 | 197 | image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor |
196 | 198 |
|
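Note: the new `sample_mode` argument is simply forwarded to the existing `retrieve_latents` helper, so `"argmax"` takes the mode of the VAE posterior instead of drawing a random sample (useful for a deterministic image-conditioning latent). Paraphrased, that helper behaves roughly like the sketch below:

```python
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # AutoencoderKL.encode returns an output wrapping a DiagonalGaussianDistribution.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)  # stochastic draw
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()  # deterministic mode
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
```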
@@ -687,3 +689,213 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip |
687 | 689 | self.set_block_state(state, block_state) |
688 | 690 |
|
689 | 691 | return components, state |
| 692 | + |
| 693 | + |
| 694 | +class FluxKontextPrepareLatentsStep(ModularPipelineBlocks): |
| 695 | + model_name = "flux_kontext" |
| 696 | + |
| 697 | + @property |
| 698 | + def expected_components(self) -> List[ComponentSpec]: |
| 699 | + return [ |
| 700 | + ComponentSpec("vae", AutoencoderKL), |
| 701 | + ComponentSpec( |
| 702 | + "image_processor", |
| 703 | + VaeImageProcessor, |
| 704 | + config=FrozenDict({"vae_scale_factor": 16}), |
| 705 | + default_creation_method="from_config", |
| 706 | + ), |
| 707 | + ] |
| 708 | + |
| 709 | + @property |
| 710 | + def description(self) -> str: |
| 711 | +        return "Prepare latents step that prepares the initial noise latents and the image latents for the image-to-image generation process with Flux Kontext" |
| 712 | + |
| 713 | + @property |
| 714 | + def inputs(self) -> List[InputParam]: |
| 715 | + return [ |
| | +            InputParam("image", description="The image(s) to edit with Flux Kontext."), |
| 716 | +            InputParam("height", type_hint=int), |
| 717 | + InputParam("width", type_hint=int), |
| 718 | + InputParam("max_area", type_hint=int, default=1024**2), |
| 719 | + InputParam("latents", type_hint=Optional[torch.Tensor]), |
| 720 | + InputParam("num_images_per_prompt", type_hint=int, default=1), |
| 721 | + InputParam("generator"), |
| 722 | + InputParam( |
| 723 | + "batch_size", |
| 724 | + required=True, |
| 725 | + type_hint=int, |
| 726 | +                description="Number of prompts; the final batch size of model inputs is `batch_size * num_images_per_prompt`. Can be generated by the input step.", |
| 727 | + ), |
| 728 | + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), |
| 729 | + ] |
| 730 | + |
| 731 | + @property |
| 732 | + def intermediate_outputs(self) -> List[OutputParam]: |
| 733 | + return [ |
| 734 | + OutputParam( |
| 735 | + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" |
| 736 | + ), |
| 737 | + OutputParam( |
| 738 | + "image_latents", type_hint=torch.Tensor, description="Latents computed from the input image(s)." |
| 739 | + ), |
| 740 | + OutputParam( |
| 741 | + "latent_ids", |
| 742 | + type_hint=torch.Tensor, |
| 743 | + description="IDs computed from the latent sequence needed for RoPE", |
| 744 | + ), |
| 745 | + OutputParam( |
| 746 | + "image_ids", |
| 747 | + type_hint=torch.Tensor, |
| 748 | + description="IDs computed from the image sequence needed for RoPE", |
| 749 | + ), |
| 750 | + ] |
| 751 | + |
| 752 | + @staticmethod |
| 753 | + def check_inputs(components, block_state): |
| 754 | + if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or ( |
| 755 | + block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0 |
| 756 | + ): |
| 757 | + logger.warning( |
| 758 | +                f"`height` and `width` have to be divisible by {components.vae_scale_factor * 2} but are {block_state.height} and {block_state.width}. Dimensions will be adjusted accordingly." |
| 759 | + ) |
| 760 | + |
| 761 | + @staticmethod |
| 762 | + def preprocess_image( |
| 763 | + image, image_processor: VaeImageProcessor, vae_scale_factor: int, latent_channels: int, _auto_resize=True |
| 764 | + ): |
| 765 | + from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS |
| 766 | + |
| 767 | + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): |
| 768 | + multiple_of = vae_scale_factor * 2 |
| 769 | + img = image[0] if isinstance(image, list) else image |
| 770 | + image_height, image_width = image_processor.get_default_height_width(img) |
| 771 | + aspect_ratio = image_width / image_height |
| 772 | + if _auto_resize: |
| 773 | + # Kontext is trained on specific resolutions, using one of them is recommended |
| 774 | + _, image_width, image_height = min( |
| 775 | + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS |
| 776 | + ) |
| 777 | + image_width = image_width // multiple_of * multiple_of |
| 778 | + image_height = image_height // multiple_of * multiple_of |
| 779 | + image = image_processor.resize(image, image_height, image_width) |
| 780 | + image = image_processor.preprocess(image, image_height, image_width) |
| 781 | + return image |
| 782 | + |
| 783 | + @staticmethod |
| 784 | + def prepare_latents( |
| 785 | + comp, |
| 786 | + image, |
| 787 | + batch_size, |
| 788 | + num_channels_latents, |
| 789 | + height, |
| 790 | + width, |
| 791 | + dtype, |
| 792 | + device, |
| 793 | + generator, |
| 794 | + latents=None, |
| 795 | + ): |
| 796 | +        # The packing helpers are standalone functions in this module rather than pipeline |
| 797 | +        # methods, so the Flux `prepare_latents` cannot be reused via "# Copied from ..." |
| 798 | +        # (e.g. `comp._pack_latents()` does not exist here); the relevant logic is reproduced below. |
| 799 | + |
| 800 | + # VAE applies 8x compression on images but we must also account for packing which requires |
| 801 | + # latent height and width to be divisible by 2. |
| 802 | + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) |
| 803 | + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) |
| 804 | + shape = (batch_size, num_channels_latents, height, width) |
| 805 | + |
| 806 | + image_latents = image_ids = None |
| 807 | + if image is not None: |
| 808 | + image = image.to(device=device, dtype=dtype) |
| 809 | + if image.shape[1] != num_channels_latents: |
| 810 | +                image_latents = _encode_vae_image(vae=comp.vae, image=image, generator=generator, sample_mode="argmax") |
| 811 | + else: |
| 812 | + image_latents = image |
| 813 | + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: |
| 814 | + # expand init_latents for batch_size |
| 815 | + additional_image_per_prompt = batch_size // image_latents.shape[0] |
| 816 | + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) |
| 817 | + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: |
| 818 | + raise ValueError( |
| 819 | + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." |
| 820 | + ) |
| 821 | + else: |
| 822 | + image_latents = torch.cat([image_latents], dim=0) |
| 823 | + |
| 824 | + image_latent_height, image_latent_width = image_latents.shape[2:] |
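| | +            # `_pack_latents` folds each 2x2 latent patch into the channel dimension, |
| | +            # turning (B, C, H, W) latents into a (B, H/2 * W/2, 4 * C) token sequence for the transformer. |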
| 825 | + image_latents = _pack_latents( |
| 826 | + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width |
| 827 | + ) |
| 828 | + image_ids = _prepare_latent_image_ids( |
| 829 | + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype |
| 830 | + ) |
| 831 | + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 |
| 832 | + image_ids[..., 0] = 1 |
| 833 | + |
| 834 | + latent_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) |
| 835 | + |
| 836 | + if latents is None: |
| 837 | + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| 838 | + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) |
| 839 | + else: |
| 840 | + latents = latents.to(device=device, dtype=dtype) |
| 841 | + |
| 842 | + return latents, image_latents, latent_ids, image_ids |
| 843 | + |
| 844 | + @torch.no_grad() |
| 845 | + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: |
| 846 | + block_state = self.get_block_state(state) |
| 847 | + |
| 848 | + block_state.height = block_state.height or components.default_height |
| 849 | + block_state.width = block_state.width or components.default_width |
| 850 | + block_state.device = components._execution_device |
| 851 | + block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? |
| 852 | + block_state.num_channels_latents = components.num_channels_latents |
| 853 | + |
| 854 | + self.check_inputs(components, block_state) |
| 855 | + |
| 856 | + # Adjust height and width if needed. |
| 857 | + max_area = block_state.max_area |
| 858 | + original_height, original_width = block_state.height, block_state.width |
| 859 | + aspect_ratio = original_width / original_height |
| 860 | + width = round((max_area * aspect_ratio) ** 0.5) |
| 861 | + height = round((max_area / aspect_ratio) ** 0.5) |
| 862 | + |
| 863 | + multiple_of = components.vae_scale_factor * 2 |
| 864 | + width = width // multiple_of * multiple_of |
| 865 | + height = height // multiple_of * multiple_of |
| 866 | + |
| 867 | + if height != original_height or width != original_width: |
| 868 | + logger.warning( |
| 869 | + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." |
| 870 | + ) |
| 871 | + block_state.height = height |
| 872 | + block_state.width = width |
| 873 | + |
| 874 | + # Process input image(s). |
| 875 | +        # `_auto_resize` is always True here; since it is a private argument, it is not exposed as a block input. |
| 876 | + image = block_state.image |
| 877 | + block_state.image = self.preprocess_image( |
| 878 | + image=image, |
| 879 | + image_processor=components.image_processor, |
| 880 | + vae_scale_factor=components.vae_scale_factor, |
| 881 | + latent_channels=components.num_channels_latents, |
| 882 | + ) |
| 883 | + |
| 884 | + batch_size = block_state.batch_size * block_state.num_images_per_prompt |
| 885 | + block_state.latents, block_state.image_latents, block_state.latent_ids, block_state.image_ids = ( |
| 886 | + self.prepare_latents( |
| 887 | +                components, |
| | +                block_state.image, |
| 888 | + batch_size, |
| 889 | + block_state.num_channels_latents, |
| 890 | + block_state.height, |
| 891 | + block_state.width, |
| 892 | + block_state.dtype, |
| 893 | + block_state.device, |
| 894 | + block_state.generator, |
| 895 | + block_state.latents, |
| 896 | + ) |
| 897 | + ) |
| 898 | + |
| 899 | + self.set_block_state(state, block_state) |
| 900 | + |
| 901 | + return components, state |
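For reference, the `max_area` resolution adjustment performed in `__call__` boils down to the following standalone arithmetic (hypothetical helper name; assumes Flux's usual `vae_scale_factor` of 8, so sizes snap to multiples of 16):

```python
def fit_to_max_area(height: int, width: int, max_area: int = 1024**2, vae_scale_factor: int = 8):
    # Preserve the requested aspect ratio while capping the pixel count at `max_area`,
    # then snap both sides down to a multiple of `vae_scale_factor * 2` so the latents
    # can later be packed into 2x2 patches.
    aspect_ratio = width / height
    width = round((max_area * aspect_ratio) ** 0.5)
    height = round((max_area / aspect_ratio) ** 0.5)
    multiple_of = vae_scale_factor * 2
    return height // multiple_of * multiple_of, width // multiple_of * multiple_of


# e.g. a 800x1000 request is adjusted to 912x1136, which stays under 1024**2 pixels.
print(fit_to_max_area(800, 1000))  # (912, 1136)
```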