     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
+from ..utils.torch_utils import is_compiled_module, get_device
 
 
 if is_torch_npu_available():
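The newly imported `get_device` does the accelerator detection that the `None` defaults below rely on. Its body is not part of this diff; a minimal sketch of the expected behavior, assuming it probes the usual PyTorch backends in order and returns a device-type string, would be:

```python
import torch

def get_device() -> str:
    # Sketch only, not the actual diffusers helper: probe common accelerators
    # in order and fall back to "cpu" when none is available. The real helper
    # may also check NPU and other backends.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```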
@@ -1084,19 +1084,19 @@ def remove_all_hooks(self):
                 accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Optional[Union[torch.device, str]] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its `forward`
+        method is called, and the model remains on the accelerator until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
 
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.device` or `str`, *optional*, defaults to `None`):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
@@ -1118,6 +1118,11 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
 
         self.remove_all_hooks()
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_model_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
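With this guard, auto-detection on a CPU-only machine now fails fast instead of trying to offload to the device it is meant to offload from. A hypothetical session illustrating the two paths (assumes `pipe` is an already loaded `DiffusionPipeline`):

```python
# CPU-only machine: get_device() returns "cpu", so auto-detection raises.
try:
    pipe.enable_model_cpu_offload()  # device=None -> get_device() -> "cpu"
except RuntimeError as err:
    print(err)

# An explicit device bypasses detection and works as before,
# assuming that accelerator is actually present.
pipe.enable_model_cpu_offload(device="cuda")
```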
@@ -1196,7 +1201,7 @@ def maybe_free_model_hooks(self):
         # make sure the model is in the same state as before calling it
         self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Optional[Union[torch.device, str]] = None):
         r"""
         Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
         dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
@@ -1207,9 +1212,9 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.device` or `str`, *optional*, defaults to `None`):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
@@ -1225,6 +1230,11 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
             )
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_sequential_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
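Taken together, both offload entry points now share the same device resolution. A short usage sketch (the checkpoint name is only illustrative; any pipeline checkpoint works):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Model-level offload: moves one whole model at a time; auto-detects the accelerator.
pipe.enable_model_cpu_offload()

# Or sequential offload for maximum memory savings at lower speed;
# the same auto-detection applies.
# pipe.enable_sequential_cpu_offload()
```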