@@ -237,7 +237,7 @@ def disable_vae_slicing(self):
         """
         self.vae.disable_slicing()
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -249,11 +249,23 @@ def enable_model_cpu_offload(self, gpu_id=0):
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        torch_device = torch.device(device)
+        device_index = torch_device.index
+
+        if gpu_id is not None and device_index is not None:
+            raise ValueError(
+                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device} "
+                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+            )
+
+        device_type = torch_device.type
+        device = torch.device(f"{device_type}:{gpu_id or torch_device.index or 0}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            device_mod = getattr(torch, device.type, None)
+            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         model_sequence = [
             self.text_encoder.text_model,
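For reference, a minimal usage sketch of the new signature. The pipeline class and checkpoint name below are illustrative, not part of this diff; any pipeline exposing `enable_model_cpu_offload` works the same way:

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; substitute your own.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

pipe.enable_model_cpu_offload()                 # unchanged default: "cuda", index 0
pipe.enable_model_cpu_offload(gpu_id=1)         # backwards-compatible: pick the GPU by id
pipe.enable_model_cpu_offload(device="cuda:1")  # new: pick the accelerator via a device string

# Specifying the index twice is ambiguous and now raises a ValueError:
# pipe.enable_model_cpu_offload(gpu_id=0, device="cuda:1")
```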
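The cache-clearing change generalizes the hard-coded `torch.cuda.empty_cache()` to any backend that exposes a `torch.<device_type>` module with an `empty_cache` function. A standalone sketch of the same dispatch pattern (the helper name is hypothetical):

```python
import torch

def empty_device_cache(device: torch.device) -> None:
    """Best-effort cache clearing for any accelerator backend that exposes
    a torch.<type> module with empty_cache (e.g. torch.cuda, torch.xpu)."""
    device_mod = getattr(torch, device.type, None)
    # hasattr(None, ...) is False, so unknown backends are skipped safely.
    if hasattr(device_mod, "empty_cache") and device_mod.is_available():
        device_mod.empty_cache()

empty_device_cache(torch.device("cuda"))  # clears the CUDA caching allocator if CUDA is present
empty_device_cache(torch.device("cpu"))   # no-op: torch.cpu exposes no empty_cache
```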