Skip to content

Commit 58bc9e8

Browse files
authored
Merge pull request #373 from AInVFX/main
v2.5.17: Proper bf16 detection for older GPUs #314
2 parents 0a66006 + 3eec584 commit 58bc9e8

File tree

5 files changed

+32
-37
lines changed

5 files changed

+32
-37
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ We're actively working on improvements and new features. To stay informed:
3636

3737
## 🚀 Updates
3838

39+
**2025.12.05 - Version 2.5.17**
40+
41+
- **🔧 Fix: Older GPU compatibility (GTX 970, etc.)** - Runtime bf16 CUBLAS probe replaces compute capability heuristics, correctly detecting unsupported GPUs without affecting RTX 20XX
42+
3943
**2025.12.05 - Version 2.5.16**
4044

4145
- **🔧 Fix: Older GPU compatibility (GTX 970, etc.)** - Automatic fallback for GPUs without bfloat16 support

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "seedvr2_videoupscaler"
33
description = "SeedVR2 official ComfyUI integration: ByteDance-Seed's one-step diffusion-based video/image upscaling with memory-efficient inference"
4-
version = "2.5.16"
4+
version = "2.5.17"
55
authors = [
66
{name = "numz"},
77
{name = "adrientoupet"}

src/core/generation_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .infer import VideoDiffusionInfer
3737
from ..data.image.transforms.divisible_crop import DivisiblePad
3838
from ..data.image.transforms.na_resize import NaResize
39+
from ..optimization.compatibility import COMPUTE_DTYPE, BFLOAT16_SUPPORTED
3940
from ..optimization.memory_manager import manage_tensor
4041
from ..utils.constants import get_script_directory
4142

@@ -371,7 +372,7 @@ def _normalize_device(device_spec: Optional[Union[str, torch.device]]) -> torch.
371372
'dit_offload_device': dit_offload_device,
372373
'vae_offload_device': vae_offload_device,
373374
'tensor_offload_device': tensor_offload_device,
374-
'compute_dtype': torch.bfloat16, # Hardcoded - gives the best compromise between memory & quality without artifacts
375+
'compute_dtype': COMPUTE_DTYPE,
375376
'interrupt_fn': interrupt_fn,
376377
'video_transform': None,
377378
'text_embeds': None,
@@ -401,7 +402,12 @@ def _normalize_device(device_spec: Optional[Union[str, torch.device]]) -> torch.
401402
f"LOCAL_RANK={os.environ['LOCAL_RANK']}",
402403
category="setup"
403404
)
404-
reason = "quality" if ctx['compute_dtype'] == torch.float32 else "compatibility"
405+
if ctx['compute_dtype'] == torch.float32:
406+
reason = "quality"
407+
elif not BFLOAT16_SUPPORTED:
408+
reason = "compatibility (GPU lacks bfloat16 CUBLAS - 7B models unsupported, 3B may have artifacts)"
409+
else:
410+
reason = "performance"
405411
debug.log(f"Unified compute dtype: {ctx['compute_dtype']} across entire pipeline for maximum {reason}", category="precision")
406412

407413
return ctx

src/optimization/compatibility.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,6 @@ def ensure_triton_compat():
4141
import os
4242

4343

44-
# Automatic bfloat16 SDPA fallback for GPUs that don't support it (e.g., GTX 970)
45-
_BFLOAT16_SDPA_WORKS = None # None=untested, True=works, False=needs float16 fallback
46-
_ORIGINAL_SDPA = torch.nn.functional.scaled_dot_product_attention
47-
48-
def _safe_scaled_dot_product_attention(query, key, value, *args, **kwargs):
49-
"""SDPA wrapper with automatic bfloat16 -> float16 fallback for old GPUs."""
50-
global _BFLOAT16_SDPA_WORKS
51-
52-
original_dtype = query.dtype
53-
54-
# Fast path: already know bfloat16 fails on this GPU
55-
if original_dtype == torch.bfloat16 and _BFLOAT16_SDPA_WORKS is False:
56-
out = _ORIGINAL_SDPA(query.half(), key.half(), value.half(), *args, **kwargs)
57-
return out.to(original_dtype)
58-
59-
try:
60-
out = _ORIGINAL_SDPA(query, key, value, *args, **kwargs)
61-
if _BFLOAT16_SDPA_WORKS is None and original_dtype == torch.bfloat16:
62-
_BFLOAT16_SDPA_WORKS = True
63-
return out
64-
except RuntimeError as e:
65-
if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e) and original_dtype == torch.bfloat16:
66-
_BFLOAT16_SDPA_WORKS = False
67-
print("⚠️ [SeedVR2] GPU does not support bfloat16 SDPA, using float16 fallback. "
68-
"Tiling artifacts or black frames may occur.")
69-
out = _ORIGINAL_SDPA(query.half(), key.half(), value.half(), *args, **kwargs)
70-
return out.to(original_dtype)
71-
raise
72-
73-
# Apply SDPA patch at module load
74-
torch.nn.functional.scaled_dot_product_attention = _safe_scaled_dot_product_attention
75-
76-
7744
# Flash Attention & Triton Compatibility Layer
7845
# 1. Flash Attention - speedup for attention operations
7946
try:
@@ -236,6 +203,24 @@ def _check_conv3d_memory_bug():
236203
print(f"🔧 Conv3d workaround active: PyTorch {torch_ver}, cuDNN {cudnn_ver} (fixing VAE 3x memory bug)")
237204

238205

206+
# Bfloat16 CUBLAS support
207+
def _probe_bfloat16_support() -> bool:
208+
if not torch.cuda.is_available():
209+
return True
210+
try:
211+
a = torch.randn(8, 8, dtype=torch.bfloat16, device='cuda:0')
212+
_ = torch.matmul(a, a)
213+
del a
214+
return True
215+
except RuntimeError as e:
216+
if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e):
217+
return False
218+
raise
219+
220+
BFLOAT16_SUPPORTED = _probe_bfloat16_support()
221+
COMPUTE_DTYPE = torch.bfloat16 if BFLOAT16_SUPPORTED else torch.float16
222+
223+
239224
def call_rope_with_stability(method, *args, **kwargs):
240225
"""
241226
Call RoPE method with stability fixes:

src/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Version information
7-
__version__ = "2.5.16"
7+
__version__ = "2.5.17"
88

99
import os
1010
import warnings

0 commit comments

Comments
 (0)