@@ -193,15 +193,15 @@ def __init__(
193193 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
194194 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
195195
196- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
196+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
197197 r"""
198198 Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
199199 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
200200 GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
201201 Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
202202 """
203- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
204- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
203+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
204+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
205205
206206 def progress_bar (self , iterable = None , total = None ):
207207 self .prior_pipe .progress_bar (iterable = iterable , total = total )
@@ -411,16 +411,16 @@ def __init__(
411411 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
412412 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
413413
414- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
414+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
415415 r"""
416416 Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
417417 text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
418418 `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
419419 Note that offloading happens on a submodule basis. Memory savings are higher than with
420420 `enable_model_cpu_offload`, but performance is lower.
421421 """
422- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
423- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
422+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
423+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
424424
425425 def progress_bar (self , iterable = None , total = None ):
426426 self .prior_pipe .progress_bar (iterable = iterable , total = total )
@@ -652,16 +652,16 @@ def __init__(
652652 def enable_xformers_memory_efficient_attention (self , attention_op : Optional [Callable ] = None ):
653653 self .decoder_pipe .enable_xformers_memory_efficient_attention (attention_op )
654654
655- def enable_sequential_cpu_offload (self , gpu_id = 0 ):
655+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
656656 r"""
657657 Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
658658 text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
659659 `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
660660 Note that offloading happens on a submodule basis. Memory savings are higher than with
661661 `enable_model_cpu_offload`, but performance is lower.
662662 """
663- self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
664- self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id )
663+ self .prior_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
664+ self .decoder_pipe .enable_sequential_cpu_offload (gpu_id = gpu_id , device = device )
665665
666666 def progress_bar (self , iterable = None , total = None ):
667667 self .prior_pipe .progress_bar (iterable = iterable , total = total )
0 commit comments