src/diffusers/pipelines/pipeline_utils.py (6 additions, 2 deletions)

@@ -994,7 +994,7 @@ def remove_all_hooks(self):
                 accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda", model_cpu_offload_seq: Optional[str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`

@@ -1051,7 +1051,11 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
 
         self._all_hooks = []
         hook = None
-        for model_str in self.model_cpu_offload_seq.split("->"):
+
+        if model_cpu_offload_seq is None:
+            model_cpu_offload_seq = self.model_cpu_offload_seq
+
+        for model_str in model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)
 
             if not isinstance(model, torch.nn.Module):
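
With this change, callers can override the pipeline's class-level `model_cpu_offload_seq` at call time; when the argument is omitted it falls back to `self.model_cpu_offload_seq`, so existing callers are unaffected. A minimal usage sketch, assuming a Stable Diffusion 3 pipeline (the checkpoint ID is illustrative, not part of this PR):

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Illustrative checkpoint; any pipeline exposing these components works the same way.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    )

    # New in this PR: pass a custom offload order instead of the class-level default.
    # This hypothetical sequence omits text_encoder_3, e.g. for a run where prompt
    # embeddings are precomputed and the T5 encoder is never called.
    pipe.enable_model_cpu_offload(
        model_cpu_offload_seq="text_encoder->text_encoder_2->transformer->vae"
    )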

[second file: path not shown in this capture]

@@ -417,13 +417,27 @@ def encode_prompt(
                 clip_skip=clip_skip,
                 clip_model_index=0,
             )
+            print(f" ")
+            print(f" after get_clip_prompt_embeds(1):")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
             prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                 prompt=prompt_2,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 clip_skip=clip_skip,
                 clip_model_index=1,
             )
+            print(f" ")
+            print(f" after get_clip_prompt_embeds(2):")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
             clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
 
             t5_prompt_embed = self._get_t5_prompt_embeds(

@@ -432,6 +446,13 @@ def encode_prompt(
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+            print(f" ")
+            print(f" after get_t5_prompt_embeds:")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
 
             clip_prompt_embeds = torch.nn.functional.pad(
                 clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])

@@ -899,6 +920,13 @@ def __call__(
             generator,
             latents,
         )
+        print(f" ")
+        print(f" before denoising loop:")
+        print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+        print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+        print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+        print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+        print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:

@@ -974,6 +1002,13 @@ def __call__(
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        print(f" ")
+        print(f" after denoising loop:")
+        print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+        print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+        print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+        print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+        print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
 
         if output_type == "latent":
             image = latents

@@ -983,6 +1018,13 @@ def __call__(
 
             image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)
+            print(f" ")
+            print(f" after decode:")
+            print(f" text_encoder: {self.text_encoder.device if self.text_encoder is not None else 'None'}")
+            print(f" text_encoder_2: {self.text_encoder_2.device if self.text_encoder_2 is not None else 'None'}")
+            print(f" text_encoder_3: {self.text_encoder_3.device if self.text_encoder_3 is not None else 'None'}")
+            print(f" transformer: {self.transformer.device if self.transformer is not None else 'None'}")
+            print(f" vae: {self.vae.device if self.vae is not None else 'None'}")
 
         # Offload all models
         self.maybe_free_model_hooks()
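
The same seven-line device dump is repeated at six checkpoints across `encode_prompt` and `__call__`. If the diagnostic is kept, it could be collapsed into one pipeline method; a minimal sketch, where the name `_print_component_devices` is hypothetical and not part of this PR:

    def _print_component_devices(self, label: str) -> None:
        # Hypothetical helper (not in this PR): print the current device of each
        # pipeline component, mirroring the repeated debug blocks above.
        print()
        print(f" {label}:")
        for name in ("text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"):
            component = getattr(self, name, None)
            print(f" {name}: {component.device if component is not None else 'None'}")

Each debug block above would then shrink to a single line, e.g. `self._print_component_devices("after decode")`.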