Skip to content

Commit d4dd5e7

Browse files
authored
Merge pull request #344 from AInVFX/main
v2.5.14: MPS device fix, VRAM swap detection, enforce physical VRAM limit
2 parents f5b902b + e2faeda commit d4dd5e7

File tree

5 files changed

+69
-20
lines changed

5 files changed

+69
-20
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ We're actively working on improvements and new features. To stay informed:
3636

3737
## 🚀 Updates
3838

39+
**2025.12.01 - Version 2.5.14**
40+
41+
- **🍎 Fix: MPS device comparison** - Normalize device strings to prevent unnecessary tensor movements
42+
- **📊 Memory: VRAM swap detection** - Peak stats now show GPU+swap breakdown when overflow occurs, with warning when swap detected
43+
- **🛡️ Memory: Enforce physical VRAM limit** - PyTorch now OOMs instead of silently swapping to shared memory (prevents extreme slowdowns on Windows)
44+
3945
**2025.11.30 - Version 2.5.13**
4046

4147
- **🔧 Fix: PyTorch 2.7+ triton import error** - Resolved installation crash caused by triton.ops import chain on newer triton versions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "seedvr2_videoupscaler"
33
description = "SeedVR2 official ComfyUI integration: ByteDance-Seed's one-step diffusion-based video/image upscaling with memory-efficient inference"
4-
version = "2.5.13"
4+
version = "2.5.14"
55
authors = [
66
{name = "numz"},
77
{name = "adrientoupet"}

src/optimization/memory_manager.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
from typing import Tuple, Dict, Any, Optional, List, Union
1414

1515

16+
def _device_str(device: Union[torch.device, str]) -> str:
17+
"""Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'."""
18+
s = str(device).upper()
19+
return 'MPS' if s.startswith('MPS') else s
20+
21+
1622
def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]:
1723
"""
1824
Get list of available compute devices for SeedVR2
@@ -106,6 +112,22 @@ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]
106112
print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")
107113

108114

115+
def _enforce_vram_limit() -> None:
116+
"""
117+
Enforce VRAM limit to physical capacity to prevent silent swap to system RAM.
118+
Called once at module load. No-op on MPS or unsupported platforms.
119+
"""
120+
if not torch.cuda.is_available():
121+
return
122+
try:
123+
for i in range(torch.cuda.device_count()):
124+
torch.cuda.set_per_process_memory_fraction(1.0, i)
125+
except Exception:
126+
pass
127+
128+
_enforce_vram_limit()
129+
130+
109131
def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float]:
110132
"""
111133
Get current VRAM usage metrics for monitoring.
@@ -591,7 +613,7 @@ def manage_tensor(
591613
target_dtype = dtype if dtype is not None else current_dtype
592614

593615
# Check if movement is actually needed
594-
needs_device_move = current_device != target_device
616+
needs_device_move = _device_str(current_device) != _device_str(target_device)
595617
needs_dtype_change = dtype is not None and current_dtype != target_dtype
596618

597619
if not needs_device_move and not needs_dtype_change:
@@ -609,8 +631,8 @@ def manage_tensor(
609631

610632
# Log the movement
611633
if debug:
612-
current_device_str = str(current_device).upper()
613-
target_device_str = str(target_device).upper()
634+
current_device_str = _device_str(current_device)
635+
target_device_str = _device_str(target_device)
614636

615637
dtype_info = ""
616638
if needs_dtype_change:
@@ -681,8 +703,8 @@ def manage_model_device(model: torch.nn.Module, target_device: torch.device, mod
681703

682704
# Extract device type for comparison (both are torch.device objects)
683705
target_type = target_device.type
684-
current_device_upper = str(current_device).upper()
685-
target_device_upper = str(target_device).upper()
706+
current_device_upper = _device_str(current_device)
707+
target_device_upper = _device_str(target_device)
686708

687709
# Compare normalized device types
688710
if current_device_upper == target_device_upper and not is_blockswap_model:
@@ -737,10 +759,10 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
737759
actual_source_device = param.device
738760
break
739761

740-
source_device_desc = str(actual_source_device).upper() if actual_source_device else str(target_device).upper()
762+
source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(target_device)
741763

742764
if debug:
743-
debug.log(f"Moving {model_name} from {source_device_desc} to {str(target_device).upper()} ({reason or 'model caching'})", category="general")
765+
debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general")
744766

745767
# Enable bypass to allow movement
746768
set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
@@ -755,7 +777,7 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
755777
model.zero_grad(set_to_none=True)
756778

757779
if debug:
758-
debug.end_timer(timer_name, f"BlockSwap model offloaded to {str(target_device).upper()}")
780+
debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}")
759781

760782
return True
761783

@@ -775,10 +797,10 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
775797
actual_current_device = param.device
776798
break
777799

778-
current_device_desc = str(actual_current_device).upper() if actual_current_device else "OFFLOAD"
800+
current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD"
779801

780802
if debug:
781-
debug.log(f"Moving {model_name} from {current_device_desc} to {str(target_device).upper()} ({reason or 'inference requirement'})", category="general")
803+
debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general")
782804

783805
timer_name = f"{model_name.lower()}_to_gpu"
784806
if debug:
@@ -818,7 +840,7 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module,
818840
blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16)
819841
total_blocks = model._block_swap_config.get('total_blocks', 32)
820842
main_device = model._block_swap_config.get('main_device', 'GPU')
821-
debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {str(main_device).upper()})", category="success")
843+
debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success")
822844
else:
823845
debug.log("BlockSwap blocks restored to configured devices", category="success")
824846

@@ -865,8 +887,8 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic
865887

866888
# Log the movement with full device strings
867889
if debug:
868-
current_device_str = str(current_device).upper()
869-
target_device_str = str(target_device).upper()
890+
current_device_str = _device_str(current_device)
891+
target_device_str = _device_str(target_device)
870892
debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general")
871893

872894
# Start timer based on direction
@@ -891,7 +913,7 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic
891913

892914
# End timer
893915
if debug:
894-
debug.end_timer(timer_name, f"{model_name} moved to {str(target_device).upper()}")
916+
debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}")
895917

896918
return True
897919

src/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Version information
7-
__version__ = "2.5.13"
7+
__version__ = "2.5.14"
88

99
import os
1010
import warnings

src/utils/debug.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
from ..utils.constants import __version__
1515

1616

17+
def _format_peak_with_swap(peak_gb: float, total_vram_gb: float) -> str:
18+
"""Format peak memory, showing swap breakdown if overflow occurred."""
19+
if total_vram_gb > 0 and peak_gb > total_vram_gb:
20+
swap_gb = peak_gb - total_vram_gb
21+
return f"{peak_gb:.2f}GB ({total_vram_gb:.0f}GB GPU + {swap_gb:.2f}GB swap)"
22+
return f"{peak_gb:.2f}GB"
23+
24+
1725
class Debug:
1826
"""
1927
Unified debug logging for generation pipeline and BlockSwap monitoring
@@ -307,7 +315,12 @@ def log_memory_state(self, label: str, show_diff: bool = True, show_tensors: boo
307315
if show_diff and self.memory_checkpoints:
308316
self._log_memory_diff(current_metrics=memory_info, force=force)
309317

310-
# Log detailed analysis if requested
318+
# Warn if swap detected (peak > physical VRAM)
319+
if memory_info['vram_total'] > 0 and memory_info['vram_peak_since_last'] > memory_info['vram_total']:
320+
self.log("VRAM swap detected - severe slowdown expected. Consider optimizing (e.g., reduce resolution, batch_size, enable BlockSwap, VAE tiling...).",
321+
level="WARNING", category="memory", force=True)
322+
323+
# Log detailed analysis if requested
311324
if detailed_tensors and tensor_stats.get('details'):
312325
self._log_detailed_tensor_analysis(details=tensor_stats['details'], force=force)
313326

@@ -361,9 +374,10 @@ def _collect_memory_metrics(self) -> Dict[str, Any]:
361374
metrics['vram_total'] = vram_info["total_gb"]
362375

363376
backend = "MPS" if (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()) else "VRAM"
377+
peak_str = _format_peak_with_swap(metrics['vram_peak_since_last'], metrics['vram_total'])
364378
metrics['summary_vram'] = (f" [{backend}] {metrics['vram_allocated']:.2f}GB allocated / "
365379
f"{metrics['vram_reserved']:.2f}GB reserved / "
366-
f"Peak: {metrics['vram_peak_since_last']:.2f}GB / "
380+
f"Peak: {peak_str} / "
367381
f"{metrics['vram_free']:.2f}GB free / "
368382
f"{metrics['vram_total']:.2f}GB total")
369383
else:
@@ -525,6 +539,13 @@ def log_peak_memory_summary(self, force: bool = True) -> None:
525539

526540
is_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and not torch.cuda.is_available()
527541

542+
# Get total VRAM for swap detection (reuse existing function)
543+
total_vram_gb = 0.0
544+
if not is_mps:
545+
vram_info = get_basic_vram_info(device=None)
546+
if "error" not in vram_info:
547+
total_vram_gb = vram_info["total_gb"]
548+
528549
self.log("", category="none", force=force)
529550
self.log("────────────────────────", category="none", force=force)
530551
self.log("Peak memory by phase:", category="memory", force=force)
@@ -539,15 +560,15 @@ def log_peak_memory_summary(self, force: bool = True) -> None:
539560
if is_mps:
540561
self.log(f" Phase {phase_num} ({phase_name}): {vram:.2f}GB", category="memory", force=force)
541562
else:
542-
self.log(f" Phase {phase_num} ({phase_name}): VRAM {vram:.2f}GB | RAM {ram:.2f}GB", category="memory", force=force)
563+
self.log(f" Phase {phase_num} ({phase_name}): {_format_peak_with_swap(vram, total_vram_gb)} | RAM {ram:.2f}GB", category="memory", force=force)
543564

544565
if is_mps:
545566
overall = max(self.phase_vram_peaks.values()) if self.phase_vram_peaks else 0
546567
self.log(f"Overall Peak: {overall:.2f}GB", category="memory", force=force)
547568
else:
548569
overall_vram = max(self.phase_vram_peaks.values()) if self.phase_vram_peaks else 0
549570
overall_ram = max(self.phase_ram_peaks.values()) if self.phase_ram_peaks else 0
550-
self.log(f"Overall peak: VRAM {overall_vram:.2f}GB | RAM {overall_ram:.2f}GB", category="memory", force=force)
571+
self.log(f"Overall peak: {_format_peak_with_swap(overall_vram, total_vram_gb)} | RAM {overall_ram:.2f}GB", category="memory", force=force)
551572

552573
@torch._dynamo.disable # Skip tracing to avoid time.time() warnings
553574
def _store_checkpoint(self, label: str, metrics: Dict[str, Any]) -> None:

0 commit comments

Comments
 (0)