Skip to content

Commit 32f9900

Browse files
authored
Merge pull request #407 from AInVFX/main
v2.5.21: fix GGUF dequant regression, MPS performance optimizations
2 parents a1486a3 + 84abef8 commit 32f9900

File tree

9 files changed

+75
-26
lines changed

9 files changed

+75
-26
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ We're actively working on improvements and new features. To stay informed:
3636

3737
## 🚀 Release Notes
3838

39+
**2025.12.12 - Version 2.5.21**
40+
41+
- **🛠️ Fix: GGUF dequantization error on MPS** - Resolved shape mismatch error introduced in 2.5.20 by skipping GGUF quantized buffers in precision conversion - these must remain in packed format for on-the-fly dequantization during inference
42+
- **🍎 MPS: Eliminate CPU sync overhead** - Skip unnecessary CPU tensor offload on Apple Silicon unified memory architecture, preventing sync stalls that caused slowdowns. Input images and output video now stay on the MPS device throughout the pipeline
43+
- **⚡ MPS: Preload text embeddings** - Load text embeddings before Phase 1 encoding to avoid a sync stall at the start of Phase 2, improving timing accuracy and throughput
44+
- **🧹 MPS: Optimized model cleanup** - Skip redundant CPU movement before model deletion on unified memory
45+
3946
**2025.12.12 - Version 2.5.20**
4047

4148
- **⚡ Expanded attention backends** - Full support for Flash Attention 2 (Ampere+), Flash Attention 3 (Hopper+), SageAttention 2, and SageAttention 3 (Blackwell/RTX 50xx), with automatic fallback chains to PyTorch SDPA when unavailable *(based on PR by [@naxci1](https://github.com/naxci1) - thank you!)*

inference_cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@
118118
prepare_runner,
119119
compute_generation_info,
120120
log_generation_start,
121-
blend_overlapping_frames
121+
blend_overlapping_frames,
122+
load_text_embeddings,
123+
script_directory
122124
)
123125
from src.core.generation_phases import (
124126
encode_all_batches,
@@ -858,6 +860,10 @@ def _process_frames_core(
858860
if runner_cache is not None:
859861
runner_cache['runner'] = runner
860862

863+
# Preload text embeddings before Phase 1 to avoid sync stall in Phase 2
864+
ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug)
865+
debug.log("Loaded text embeddings for DiT", category="dit")
866+
861867
# Compute generation info and log start (handles prepending internally)
862868
frames_tensor, gen_info = compute_generation_info(
863869
ctx=ctx,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "seedvr2_videoupscaler"
33
description = "SeedVR2 official ComfyUI integration: ByteDance-Seed's one-step diffusion-based video/image upscaling with memory-efficient inference"
4-
version = "2.5.20"
4+
version = "2.5.21"
55
authors = [
66
{name = "numz"},
77
{name = "adrientoupet"}

src/core/generation_phases.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,11 @@ def encode_all_batches(
231231
if images is None:
232232
raise ValueError("Images to encode must be provided")
233233
else:
234-
ctx['input_images'] = images
234+
# MPS: keep on device to avoid sync overhead in Phase 4 color correction
235+
if ctx['vae_device'].type == 'mps' and images.device.type != 'mps':
236+
ctx['input_images'] = images.to(ctx['vae_device'])
237+
else:
238+
ctx['input_images'] = images
235239

236240
# Get total frame count from context (set in video_upscaler before encoding)
237241
total_frames = ctx.get('total_frames', len(images))
@@ -529,6 +533,10 @@ def encode_all_batches(
529533
manage_model_device(model=runner.vae, target_device=ctx['vae_offload_device'],
530534
model_name="VAE", debug=debug, reason="VAE offload", runner=runner)
531535

536+
# MPS: sync to get accurate timing and free memory before Phase 2
537+
if ctx['vae_device'].type == 'mps':
538+
torch.mps.synchronize()
539+
532540
debug.end_timer("phase1_encoding", "Phase 1: VAE encoding complete", show_breakdown=True)
533541
debug.log_memory_state("After phase 1 (VAE encoding)", show_tensors=False)
534542

@@ -860,7 +868,13 @@ def decode_all_batches(
860868

861869
# Pre-allocate final_video at the START of decode phase (before any batch processing)
862870
# This ensures we only need memory for final_video + 1 batch, not final_video + all batch_samples
863-
target_device = ctx['tensor_offload_device'] if ctx['tensor_offload_device'] is not None else 'cpu'
871+
# MPS: keep on device (unified memory, no benefit to CPU offload)
872+
if ctx['tensor_offload_device'] is not None:
873+
target_device = ctx['tensor_offload_device']
874+
elif ctx['vae_device'].type == 'mps':
875+
target_device = ctx['vae_device']
876+
else:
877+
target_device = 'cpu'
864878
channels_str = "RGBA" if C == 4 else "RGB"
865879
required_gb = (total_frames * true_h * true_w * C * 2) / (1024**3)
866880
debug.log(f"Pre-allocating output tensor: {total_frames} frames, {true_w}x{true_h}px, {channels_str} ({required_gb:.2f}GB)",
@@ -1040,6 +1054,10 @@ def decode_all_batches(
10401054
if 'all_upscaled_latents' in ctx:
10411055
release_tensor_collection(ctx['all_upscaled_latents'])
10421056
del ctx['all_upscaled_latents']
1057+
1058+
# MPS: sync to get accurate timing and free memory before Phase 4
1059+
if ctx['vae_device'].type == 'mps':
1060+
torch.mps.synchronize()
10431061

10441062
debug.end_timer("phase3_decoding", "Phase 3: VAE decoding complete", show_breakdown=True)
10451063
debug.log_memory_state("After phase 3 (VAE decoding)", show_tensors=False)

src/core/generation_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,12 @@ def _normalize_device(device_spec: Optional[Union[str, torch.device]]) -> torch.
350350
vae_device = _normalize_device(vae_device)
351351
dit_offload_device = _normalize_device(dit_offload_device) if dit_offload_device is not None else None
352352
vae_offload_device = _normalize_device(vae_offload_device) if vae_offload_device is not None else None
353-
tensor_offload_device = _normalize_device(tensor_offload_device) if tensor_offload_device is not None else None
353+
# MPS unified memory: CPU offload causes sync overhead with no memory benefit
354+
is_mps = dit_device.type == 'mps' or vae_device.type == 'mps'
355+
if is_mps and tensor_offload_device is not None and str(tensor_offload_device) == 'cpu':
356+
tensor_offload_device = None
357+
else:
358+
tensor_offload_device = _normalize_device(tensor_offload_device) if tensor_offload_device is not None else None
354359

355360
# Set LOCAL_RANK to 0 for single-GPU inference mode
356361
# CLI multi-GPU uses CUDA_VISIBLE_DEVICES to restrict visibility per worker

src/interfaces/video_upscaler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
setup_generation_context,
2020
prepare_runner,
2121
compute_generation_info,
22-
log_generation_start
22+
log_generation_start,
23+
load_text_embeddings,
24+
script_directory
2325
)
2426
from ..optimization.memory_manager import (
2527
cleanup_text_embeddings,
@@ -437,6 +439,10 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None:
437439
# Store cache context in ctx for use in generation phases
438440
ctx['cache_context'] = cache_context
439441

442+
# Preload text embeddings before Phase 1 to avoid sync stall in Phase 2
443+
ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug)
444+
debug.log("Loaded text embeddings for DiT", category="dit")
445+
440446
debug.log_memory_state("After model preparation", show_tensors=False, detailed_tensors=False)
441447
debug.end_timer("model_preparation", "Model preparation", force=True, show_breakdown=True)
442448

src/optimization/compatibility.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,11 @@ def _force_nadit_precision(self, target_dtype: torch.dtype = torch.bfloat16) ->
826826
param.data = param.data.to(target_dtype)
827827
converted_count += 1
828828

829-
# Also convert buffers
829+
# Also convert buffers (skip GGUF quantized buffers - they have tensor_type attribute)
830830
for name, buffer in self.dit_model.named_buffers():
831+
# Skip GGUF quantized buffers - these must stay in packed format for on-the-fly dequantization
832+
if hasattr(buffer, 'tensor_type'):
833+
continue
831834
if buffer.dtype != target_dtype:
832835
if buffer.device.type == "mps":
833836
temp_cpu = buffer.data.to("cpu")

src/optimization/memory_manager.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,15 +1050,17 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool
10501050

10511051
# Move model off GPU if needed
10521052
if param_device.type not in ['meta', 'cpu']:
1053-
# Get offload target - default to 'cpu' if not configured or set to 'none'
1054-
offload_target = getattr(runner, '_dit_offload_device', None)
1055-
if offload_target is None or offload_target == 'none':
1056-
offload_target = torch.device('cpu')
1057-
1058-
# Move model off GPU (either for caching or before deletion)
1059-
reason = "model caching" if cache_model else "releasing GPU memory"
1060-
manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
1061-
debug=debug, reason=reason, runner=runner)
1053+
# MPS: skip CPU movement before deletion (unified memory, just causes sync)
1054+
if param_device.type == 'mps' and not cache_model:
1055+
if debug:
1056+
debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup")
1057+
else:
1058+
offload_target = getattr(runner, '_dit_offload_device', None)
1059+
if offload_target is None or offload_target == 'none':
1060+
offload_target = torch.device('cpu')
1061+
reason = "model caching" if cache_model else "releasing GPU memory"
1062+
manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT",
1063+
debug=debug, reason=reason, runner=runner)
10621064
elif param_device.type == 'meta' and debug:
10631065
debug.log("DiT on meta device - keeping structure for cache", category="cleanup")
10641066
except StopIteration:
@@ -1126,15 +1128,17 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool
11261128

11271129
# Move model off GPU if needed
11281130
if param_device.type not in ['meta', 'cpu']:
1129-
# Get offload target - default to 'cpu' if not configured or set to 'none'
1130-
offload_target = getattr(runner, '_vae_offload_device', None)
1131-
if offload_target is None or offload_target == 'none':
1132-
offload_target = torch.device('cpu')
1133-
1134-
# Move model off GPU (either for caching or before deletion)
1135-
reason = "model caching" if cache_model else "releasing GPU memory"
1136-
manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
1137-
debug=debug, reason=reason, runner=runner)
1131+
# MPS: skip CPU movement before deletion (unified memory, just causes sync)
1132+
if param_device.type == 'mps' and not cache_model:
1133+
if debug:
1134+
debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup")
1135+
else:
1136+
offload_target = getattr(runner, '_vae_offload_device', None)
1137+
if offload_target is None or offload_target == 'none':
1138+
offload_target = torch.device('cpu')
1139+
reason = "model caching" if cache_model else "releasing GPU memory"
1140+
manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE",
1141+
debug=debug, reason=reason, runner=runner)
11381142
elif param_device.type == 'meta' and debug:
11391143
debug.log("VAE on meta device - keeping structure for cache", category="cleanup")
11401144
except StopIteration:

src/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Version information
7-
__version__ = "2.5.20"
7+
__version__ = "2.5.21"
88

99
import os
1010
import warnings

0 commit comments

Comments
 (0)