@@ -237,7 +237,7 @@ def disable_vae_slicing(self):
237237 """
238238 self .vae .disable_slicing ()
239239
240- def enable_model_cpu_offload (self , gpu_id = 0 ):
240+ def enable_model_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [ torch . device , str ] = "cuda" ):
241241 r"""
242242 Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
243243 to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -249,11 +249,23 @@ def enable_model_cpu_offload(self, gpu_id=0):
249249 else :
250250 raise ImportError ("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." )
251251
252- device = torch .device (f"cuda:{ gpu_id } " )
252+ torch_device = torch .device (device )
253+ device_index = torch_device .index
254+
255+ if gpu_id is not None and device_index is not None :
256+ raise ValueError (
257+ f"You have passed both `gpu_id`={ gpu_id } and an index as part of the passed device `device`={ device } "
258+ f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={ torch_device .type } "
259+ )
260+
261+ device_type = torch_device .type
262+ device = torch .device (f"{ device_type } :{ gpu_id or torch_device .index } " )
253263
254264 if self .device .type != "cpu" :
255265 self .to ("cpu" , silence_dtype_warnings = True )
256- torch .cuda .empty_cache () # otherwise we don't see the memory savings (but they probably exist)
266+ device_mod = getattr (torch , device .type , None )
267+ if hasattr (device_mod , "empty_cache" ) and device_mod .is_available ():
268+ device_mod .empty_cache () # otherwise we don't see the memory savings (but they probably exist)
257269
258270 model_sequence = [
259271 self .text_encoder .text_model ,
0 commit comments