@@ -193,15 +193,15 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
         GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
         Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
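For context (not part of the diff): with the new signature, a caller can pick the offload target either by index, as before, or by device string / `torch.device`, which is what this change threads through to both `prior_pipe` and `decoder_pipe`. A minimal usage sketch, assuming a diffusers version that includes this change; the checkpoint name is illustrative:

```python
import torch
from diffusers import AutoPipelineForText2Image

# Checkpoint name is illustrative; any combined Kandinsky pipeline works.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)

# Either select the accelerator by index (the pre-existing knob) ...
pipe.enable_sequential_cpu_offload(gpu_id=1)

# ... or, after this change, by device string or torch.device:
# pipe.enable_sequential_cpu_offload(device="cuda:1")
```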
@@ -411,16 +411,16 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
         Note that offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,16 +652,16 @@ def __init__(
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
         Note that offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
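The hunks above only show the combined pipelines forwarding both arguments; how `gpu_id` and `device` are reconciled lives in the underlying pipeline method, which is not part of this diff. A plausible sketch of that resolution logic, with a hypothetical helper name, under the assumption that an explicit `gpu_id` must not conflict with an index embedded in `device`:

```python
from typing import Optional, Union

import torch

def _resolve_offload_device(
    gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"
) -> torch.device:
    """Hypothetical helper: merge the legacy `gpu_id` with the new `device` arg."""
    torch_device = torch.device(device)
    if gpu_id is not None and torch_device.index is not None:
        # Both ways of choosing a device index were given; refuse to guess.
        raise ValueError(
            f"Cannot pass both gpu_id={gpu_id} and device={device} with an index."
        )
    index = gpu_id if gpu_id is not None else (torch_device.index or 0)
    return torch.device(f"{torch_device.type}:{index}")

assert _resolve_offload_device() == torch.device("cuda:0")
assert _resolve_offload_device(gpu_id=1) == torch.device("cuda:1")
assert _resolve_offload_device(device="cuda:2") == torch.device("cuda:2")
```

This keeps the old `gpu_id` call sites working while letting new callers pass a full device specification, which is why both keyword arguments are forwarded unchanged in the diff.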