2525 is_accelerate_available ,
2626 logging ,
2727)
28+ from ..utils .torch_utils import get_device
2829
2930
3031if is_accelerate_available ():
@@ -161,7 +162,9 @@ def __call__(self, hooks, model_id, model, execution_device):
161162
162163 current_module_size = model .get_memory_footprint ()
163164
164- mem_on_device = torch .cuda .mem_get_info (execution_device .index )[0 ]
165+ device_type = execution_device .type
166+ device_module = getattr (torch , device_type , torch .cuda )
167+ mem_on_device = device_module .mem_get_info (execution_device .index )[0 ]
165168 mem_on_device = mem_on_device - self .memory_reserve_margin
166169 if current_module_size < mem_on_device :
167170 return []
@@ -301,7 +304,7 @@ class ComponentsManager:
301304 cm.add("vae", vae_model, collection="sdxl")
302305
303306 # Enable auto offloading
304- cm.enable_auto_cpu_offload(device="cuda" )
307+ cm.enable_auto_cpu_offload()
305308
306309 # Retrieve components
307310 unet = cm.get_one(name="unet", collection="sdxl")
@@ -490,6 +493,8 @@ def remove(self, component_id: str = None):
490493 gc .collect ()
491494 if torch .cuda .is_available ():
492495 torch .cuda .empty_cache ()
496+ if torch .xpu .is_available ():
497+ torch .xpu .empty_cache ()
493498
494499 # YiYi TODO: rename to search_components for now, may remove this method
495500 def search_components (
@@ -678,7 +683,7 @@ def matches_pattern(component_id, pattern, exact_match=False):
678683
679684 return get_return_dict (matches , return_dict_with_names )
680685
681- def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = "cuda" , memory_reserve_margin = "3GB" ):
686+ def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = None , memory_reserve_margin = "3GB" ):
682687 """
683688 Enable automatic CPU offloading for all components.
684689
@@ -704,6 +709,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
704709
705710 self .disable_auto_cpu_offload ()
706711 offload_strategy = AutoOffloadStrategy (memory_reserve_margin = memory_reserve_margin )
712+ if device is None :
713+ device = get_device ()
707714 device = torch .device (device )
708715 if device .index is None :
709716 device = torch .device (f"{ device .type } :{ 0 } " )
0 commit comments