1313from typing import Tuple , Dict , Any , Optional , List , Union
1414
1515
def _device_str(device: Union[torch.device, str]) -> str:
    """
    Return an uppercase label for ``device`` for comparisons and log text.

    Any label that begins with ``MPS`` (for example ``mps`` or ``mps:0``)
    collapses to plain ``'MPS'``; every other label is returned uppercased
    with its index suffix, if any, left in place.
    """
    label = str(device).upper()
    if label.startswith('MPS'):
        return 'MPS'
    return label
20+
21+
1622def get_device_list (include_none : bool = False , include_cpu : bool = False ) -> List [str ]:
1723 """
1824 Get list of available compute devices for SeedVR2
@@ -106,6 +112,22 @@ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]
106112 print (f"⚠️ Memory check failed: { vram_info ['error' ]} - No available backend!" )
107113
108114
def _enforce_vram_limit() -> None:
    """
    Cap this process's CUDA allocations at 100% of each device's memory.

    Calls ``torch.cuda.set_per_process_memory_fraction(1.0, i)`` for every
    visible CUDA device. The intent (per the original author) is to keep
    allocations from silently spilling over into system RAM.

    Called once at module load. Returns immediately when CUDA is not
    available (e.g. MPS or CPU-only hosts). The call is best-effort: a
    failure on one device is recorded at DEBUG level and the remaining
    devices are still processed, so importing the module never raises.
    """
    if not torch.cuda.is_available():
        return

    try:
        device_count = torch.cuda.device_count()
    except Exception:
        # Could not enumerate devices; nothing to configure.
        return

    for index in range(device_count):
        # Guard each device separately so one failing device does not
        # prevent the limit from being applied to the others.
        try:
            torch.cuda.set_per_process_memory_fraction(1.0, index)
        except Exception as exc:
            import logging

            logging.getLogger(__name__).debug(
                "Could not set memory fraction on CUDA device %d: %s", index, exc
            )


_enforce_vram_limit()
129+
130+
109131def get_vram_usage (device : Optional [torch .device ] = None , debug : Optional ['Debug' ] = None ) -> Tuple [float , float , float ]:
110132 """
111133 Get current VRAM usage metrics for monitoring.
@@ -591,7 +613,7 @@ def manage_tensor(
591613 target_dtype = dtype if dtype is not None else current_dtype
592614
593615 # Check if movement is actually needed
594- needs_device_move = current_device != target_device
616+ needs_device_move = _device_str ( current_device ) != _device_str ( target_device )
595617 needs_dtype_change = dtype is not None and current_dtype != target_dtype
596618
597619 if not needs_device_move and not needs_dtype_change :
@@ -609,8 +631,8 @@ def manage_tensor(
609631
610632 # Log the movement
611633 if debug :
612- current_device_str = str (current_device ). upper ( )
613- target_device_str = str (target_device ). upper ( )
634+ current_device_str = _device_str (current_device )
635+ target_device_str = _device_str (target_device )
614636
615637 dtype_info = ""
616638 if needs_dtype_change :
@@ -681,8 +703,8 @@ def manage_model_device(model: torch.nn.Module, target_device: torch.device, mod
681703
682704 # Extract device type for comparison (both are torch.device objects)
683705 target_type = target_device .type
684- current_device_upper = str (current_device ). upper ( )
685- target_device_upper = str (target_device ). upper ( )
706+ current_device_upper = _device_str (current_device )
707+ target_device_upper = _device_str (target_device )
686708
687709 # Compare normalized device types
688710 if current_device_upper == target_device_upper and not is_blockswap_model :
@@ -737,10 +759,10 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
737759 actual_source_device = param .device
738760 break
739761
740- source_device_desc = str (actual_source_device ). upper () if actual_source_device else str (target_device ). upper ( )
762+ source_device_desc = _device_str (actual_source_device ) if actual_source_device else _device_str (target_device )
741763
742764 if debug :
743- debug .log (f"Moving { model_name } from { source_device_desc } to { str (target_device ). upper ( )} ({ reason or 'model caching' } )" , category = "general" )
765+ debug .log (f"Moving { model_name } from { source_device_desc } to { _device_str (target_device )} ({ reason or 'model caching' } )" , category = "general" )
744766
745767 # Enable bypass to allow movement
746768 set_blockswap_bypass (runner = runner , bypass = True , debug = debug )
@@ -755,7 +777,7 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
755777 model .zero_grad (set_to_none = True )
756778
757779 if debug :
758- debug .end_timer (timer_name , f"BlockSwap model offloaded to { str (target_device ). upper ( )} " )
780+ debug .end_timer (timer_name , f"BlockSwap model offloaded to { _device_str (target_device )} " )
759781
760782 return True
761783
@@ -775,10 +797,10 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
775797 actual_current_device = param .device
776798 break
777799
778- current_device_desc = str (actual_current_device ). upper ( ) if actual_current_device else "OFFLOAD"
800+ current_device_desc = _device_str (actual_current_device ) if actual_current_device else "OFFLOAD"
779801
780802 if debug :
781- debug .log (f"Moving { model_name } from { current_device_desc } to { str (target_device ). upper ( )} ({ reason or 'inference requirement' } )" , category = "general" )
803+ debug .log (f"Moving { model_name } from { current_device_desc } to { _device_str (target_device )} ({ reason or 'inference requirement' } )" , category = "general" )
782804
783805 timer_name = f"{ model_name .lower ()} _to_gpu"
784806 if debug :
@@ -818,7 +840,7 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
818840 blocks_on_gpu = model ._block_swap_config .get ('total_blocks' , 32 ) - model ._block_swap_config .get ('blocks_swapped' , 16 )
819841 total_blocks = model ._block_swap_config .get ('total_blocks' , 32 )
820842 main_device = model ._block_swap_config .get ('main_device' , 'GPU' )
821- debug .log (f"BlockSwap blocks restored to configured devices ({ blocks_on_gpu } /{ total_blocks } blocks on { str (main_device ). upper ( )} )" , category = "success" )
843+ debug .log (f"BlockSwap blocks restored to configured devices ({ blocks_on_gpu } /{ total_blocks } blocks on { _device_str (main_device )} )" , category = "success" )
822844 else :
823845 debug .log ("BlockSwap blocks restored to configured devices" , category = "success" )
824846
@@ -865,8 +887,8 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic
865887
866888 # Log the movement with full device strings
867889 if debug :
868- current_device_str = str (current_device ). upper ( )
869- target_device_str = str (target_device ). upper ( )
890+ current_device_str = _device_str (current_device )
891+ target_device_str = _device_str (target_device )
870892 debug .log (f"Moving { model_name } from { current_device_str } to { target_device_str } ({ reason } )" , category = "general" )
871893
872894 # Start timer based on direction
@@ -891,7 +913,7 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic
891913
892914 # End timer
893915 if debug :
894- debug .end_timer (timer_name , f"{ model_name } moved to { str (target_device ). upper ( )} " )
916+ debug .end_timer (timer_name , f"{ model_name } moved to { _device_str (target_device )} " )
895917
896918 return True
897919
0 commit comments