From a16346826dd336a3638bfc332b63a7f9e14f5d74 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sat, 14 Dec 2024 04:20:30 +0100
Subject: [PATCH] add

---
 src/diffusers/pipelines/pipeline_utils.py |  8 +++-
 .../pipeline_stable_diffusion_3.py        | 42 +++++++++++++++++++
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a504184ea2f2..9ff08da273ff 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -994,7 +994,7 @@ def remove_all_hooks(self):
                 accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []

-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda", model_cpu_offload_seq: Optional[str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -1051,7 +1051,11 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t

         self._all_hooks = []
         hook = None
-        for model_str in self.model_cpu_offload_seq.split("->"):
+
+        if model_cpu_offload_seq is None:
+            model_cpu_offload_seq = self.model_cpu_offload_seq
+
+        for model_str in model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)

             if not isinstance(model, torch.nn.Module):
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 513f86441c3a..cd4198d79161 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -417,6 +417,13 @@ def encode_prompt(
                 clip_skip=clip_skip,
                 clip_model_index=0,
             )
+            print(f" ")
+            print(f" after get_clip_prompt_embeds(1):")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
             prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                 prompt=prompt_2,
                 device=device,
@@ -424,6 +431,13 @@ def encode_prompt(
                 clip_skip=clip_skip,
                 clip_model_index=1,
             )
+            print(f" ")
+            print(f" after get_clip_prompt_embeds(2):")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
             clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

             t5_prompt_embed = self._get_t5_prompt_embeds(
@@ -432,6 +446,13 @@ def encode_prompt(
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+            print(f" ")
+            print(f" after get_t5_prompt_embeds:")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")

             clip_prompt_embeds = torch.nn.functional.pad(
                 clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
@@ -899,6 +920,13 @@ def __call__(
             generator,
             latents,
         )
+        print(f" ")
+        print(f" before denoising loop:")
+        print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+        print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+        print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+        print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+        print(f" vae: {self.vae.device if self.vae is not None else 'None'}")

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -974,6 +1002,13 @@ def __call__(

                 if XLA_AVAILABLE:
                     xm.mark_step()
+        print(f" ")
+        print(f" after denoising loop:")
+        print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+        print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+        print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+        print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+        print(f" vae: {self.vae.device if self.vae is not None else 'None'}")

         if output_type == "latent":
             image = latents
@@ -983,6 +1018,13 @@ def __call__(

             image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)
+        print(f" ")
+        print(f" after decode:")
+        print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+        print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+        print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+        print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+        print(f" vae: {self.vae.device if self.vae is not None else 'None'}")

         # Offload all models
         self.maybe_free_model_hooks()