Skip to content

Commit f615f00

Browse files
hlkysayakpaul
andauthored
Fix enable_sequential_cpu_offload in test_kandinsky_combined (#10324)
Co-authored-by: Sayak Paul <[email protected]>
1 parent 6aaa051 commit f615f00

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,15 @@ def __init__(
193193
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
194194
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
195195

196-
def enable_sequential_cpu_offload(self, gpu_id=0):
196+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
197197
r"""
198198
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
199199
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
200200
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
201201
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
202202
"""
203-
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
204-
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
203+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
204+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
205205

206206
def progress_bar(self, iterable=None, total=None):
207207
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -411,16 +411,16 @@ def __init__(
411411
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
412412
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
413413

414-
def enable_sequential_cpu_offload(self, gpu_id=0):
414+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
415415
r"""
416416
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
417417
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
418418
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
419419
Note that offloading happens on a submodule basis. Memory savings are higher than with
420420
`enable_model_cpu_offload`, but performance is lower.
421421
"""
422-
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
423-
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
422+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
423+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
424424

425425
def progress_bar(self, iterable=None, total=None):
426426
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,16 +652,16 @@ def __init__(
652652
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
653653
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
654654

655-
def enable_sequential_cpu_offload(self, gpu_id=0):
655+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
656656
r"""
657657
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
658658
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
659659
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
660660
Note that offloading happens on a submodule basis. Memory savings are higher than with
661661
`enable_model_cpu_offload`, but performance is lower.
662662
"""
663-
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
664-
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
663+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
664+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
665665

666666
def progress_bar(self, iterable=None, total=None):
667667
self.prior_pipe.progress_bar(iterable=iterable, total=total)

0 commit comments

Comments
 (0)