@@ -127,6 +127,86 @@ def retrieve_latents(
127127 raise AttributeError ("Could not access latents of provided encoder_output" )
128128
129129
def prepare_latents_img2img(
    vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
    """Build the starting latents for an image-to-image denoising run.

    The input is moved to `device`/`dtype`, encoded with the VAE unless it already has 4 channels,
    repeated to the effective batch size and, optionally, noised by the scheduler.

    Args:
        vae: Autoencoder. The code reads `vae.config` (`latents_mean`, `latents_std`, `force_upcast`,
            `scaling_factor`) and calls `vae.encode(...)` and `vae.to(...)`. Not touched when `image`
            already has 4 channels.
        scheduler: Only `scheduler.add_noise(latents, noise, timestep)` is called, and only when
            `add_noise` is true.
        image (`torch.Tensor`): Input batch. When `image.shape[1] == 4` it is used directly as latents
            (no encoding and no scaling); otherwise it is encoded with `vae`.
        timestep: Forwarded unchanged to `scheduler.add_noise`.
        batch_size (`int`): Multiplied by `num_images_per_prompt` to get the effective batch size.
        num_images_per_prompt (`int`): See `batch_size`.
        dtype: dtype of the returned latents and of the sampled noise.
        device: Device the image, the latent statistics and the noise are placed on.
        generator: Forwarded to `retrieve_latents` and `randn_tensor`. A list must contain exactly one
            entry per sample of the effective batch.
        add_noise (`bool`): Whether to noise the latents with `scheduler.add_noise`.

    Returns:
        `torch.Tensor`: The prepared latents; the first dimension is the effective batch size whenever
        the input batch was not larger than it.

    Raises:
        ValueError: If `image` is not a tensor, if a generator list does not match the effective batch
            size, or if the image batch cannot be evenly repeated up to the effective batch size.
    """
    # Everything below relies on tensor methods (`.to`, `.shape`). The previous check also let PIL images
    # and lists through, which then crashed with an AttributeError on the very next statement.
    if not isinstance(image, torch.Tensor):
        raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")

    image = image.to(device=device, dtype=dtype)

    # effective batch size
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # 4 channels: the input is taken to be latents already and is neither encoded nor scaled
        init_latents = image
    else:
        latents_mean = latents_std = None
        # NOTE: the (1, 4, 1, 1) view assumes 4 latent channels
        if getattr(vae.config, "latents_mean", None) is not None:
            latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
        if getattr(vae.config, "latents_std", None) is not None:
            latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)

        # make sure the VAE is in float32 mode, as it overflows in float16
        if vae.config.force_upcast:
            image = image.float()
            vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if image.shape[0] < batch_size:
                if batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)

            # encode sample by sample so that every sample uses its own generator
            init_latents = torch.cat(
                [retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)],
                dim=0,
            )
        else:
            init_latents = retrieve_latents(vae.encode(image), generator=generator)

        # NOTE(review): the VAE ends up in `dtype`, not necessarily the dtype it had before the upcast
        if vae.config.force_upcast:
            vae.to(dtype)

        init_latents = init_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=device, dtype=dtype)
            latents_std = latents_std.to(device=device, dtype=dtype)
            init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
        else:
            init_latents = vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0]:
        if batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    else:
        # nothing to repeat; the single-element cat is kept so the result is a new tensor, as before
        init_latents = torch.cat([init_latents], dim=0)

    if add_noise:
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        # get latents
        init_latents = scheduler.add_noise(init_latents, noise, timestep)

    return init_latents
208+
209+
130210class StableDiffusionXLInputStep (PipelineBlock ):
131211 model_name = "stable-diffusion-xl"
132212
@@ -751,89 +831,6 @@ def intermediates_inputs(self) -> List[InputParam]:
def intermediates_outputs(self) -> List[OutputParam]:
    """Describe the one intermediate value this block produces: the initial `latents` tensor."""
    latents_output = OutputParam(
        "latents",
        type_hint=torch.Tensor,
        description="The initial latents to use for the denoising process",
    )
    return [latents_output]
753833
754- # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components
755- # YiYi TODO: refactor using _encode_vae_image
756- @staticmethod
def prepare_latents_img2img(
    components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
    """Build the starting latents for an image-to-image denoising run.

    The input is moved to `device`/`dtype`, encoded with `components.vae` unless it already has
    4 channels, repeated to the effective batch size and, optionally, noised by `components.scheduler`.

    Args:
        components: Object exposing `vae` and `scheduler`. From the VAE the code reads `config`
            (`latents_mean`, `latents_std`, `force_upcast`, `scaling_factor`) and calls `encode(...)`
            and `to(...)`; from the scheduler only `add_noise(latents, noise, timestep)` is called.
            Neither is touched when `image` already has 4 channels and `add_noise` is false.
        image (`torch.Tensor`): Input batch. When `image.shape[1] == 4` it is used directly as latents
            (no encoding and no scaling); otherwise it is encoded with the VAE.
        timestep: Forwarded unchanged to `scheduler.add_noise`.
        batch_size (`int`): Multiplied by `num_images_per_prompt` to get the effective batch size.
        num_images_per_prompt (`int`): See `batch_size`.
        dtype: dtype of the returned latents and of the sampled noise.
        device: Device the image, the latent statistics and the noise are placed on.
        generator: Forwarded to `retrieve_latents` and `randn_tensor`. A list must contain exactly one
            entry per sample of the effective batch.
        add_noise (`bool`): Whether to noise the latents with `scheduler.add_noise`.

    Returns:
        `torch.Tensor`: The prepared latents.

    Raises:
        ValueError: If `image` is not a tensor, if a generator list does not match the effective batch
            size, or if the image batch cannot be evenly repeated up to the effective batch size.
    """
    # Everything below relies on tensor methods (`.to`, `.shape`). The previous check also let PIL images
    # and lists through, which then crashed with an AttributeError on the very next statement.
    if not isinstance(image, torch.Tensor):
        raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")

    image = image.to(device=device, dtype=dtype)

    # effective batch size
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # 4 channels: the input is taken to be latents already and is neither encoded nor scaled
        init_latents = image
    else:
        vae = components.vae

        latents_mean = latents_std = None
        # NOTE: the (1, 4, 1, 1) view assumes 4 latent channels
        if getattr(vae.config, "latents_mean", None) is not None:
            latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
        if getattr(vae.config, "latents_std", None) is not None:
            latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)

        # make sure the VAE is in float32 mode, as it overflows in float16
        if vae.config.force_upcast:
            image = image.float()
            vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if image.shape[0] < batch_size:
                if batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)

            # encode sample by sample so that every sample uses its own generator
            init_latents = torch.cat(
                [retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)],
                dim=0,
            )
        else:
            init_latents = retrieve_latents(vae.encode(image), generator=generator)

        # NOTE(review): the VAE ends up in `dtype`, not necessarily the dtype it had before the upcast
        if vae.config.force_upcast:
            vae.to(dtype)

        init_latents = init_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=device, dtype=dtype)
            latents_std = latents_std.to(device=device, dtype=dtype)
            init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
        else:
            init_latents = vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0]:
        if batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    else:
        # nothing to repeat; the single-element cat is kept so the result is a new tensor, as before
        init_latents = torch.cat([init_latents], dim=0)

    if add_noise:
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        # get latents
        init_latents = components.scheduler.add_noise(init_latents, noise, timestep)

    return init_latents
836-
837834 @torch .no_grad ()
838835 def __call__ (self , components : StableDiffusionXLModularLoader , state : PipelineState ) -> PipelineState :
839836 block_state = self .get_block_state (state )
@@ -842,8 +839,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt
842839 block_state .device = components ._execution_device
843840 block_state .add_noise = True if block_state .denoising_start is None else False
844841 if block_state .latents is None :
845- block_state .latents = self .prepare_latents_img2img (
846- components ,
842+ block_state .latents = prepare_latents_img2img (
843+ components .vae ,
844+ components .scheduler ,
847845 block_state .image_latents ,
848846 block_state .latent_timestep ,
849847 block_state .batch_size ,
0 commit comments