@@ -41,39 +41,6 @@ def ensure_triton_compat():
4141import os
4242
4343
44- # Automatic bfloat16 SDPA fallback for GPUs that don't support it (e.g., GTX 970)
45- _BFLOAT16_SDPA_WORKS = None # None=untested, True=works, False=needs float16 fallback
46- _ORIGINAL_SDPA = torch .nn .functional .scaled_dot_product_attention
47-
48- def _safe_scaled_dot_product_attention (query , key , value , * args , ** kwargs ):
49- """SDPA wrapper with automatic bfloat16 -> float16 fallback for old GPUs."""
50- global _BFLOAT16_SDPA_WORKS
51-
52- original_dtype = query .dtype
53-
54- # Fast path: already know bfloat16 fails on this GPU
55- if original_dtype == torch .bfloat16 and _BFLOAT16_SDPA_WORKS is False :
56- out = _ORIGINAL_SDPA (query .half (), key .half (), value .half (), * args , ** kwargs )
57- return out .to (original_dtype )
58-
59- try :
60- out = _ORIGINAL_SDPA (query , key , value , * args , ** kwargs )
61- if _BFLOAT16_SDPA_WORKS is None and original_dtype == torch .bfloat16 :
62- _BFLOAT16_SDPA_WORKS = True
63- return out
64- except RuntimeError as e :
65- if "CUBLAS_STATUS_NOT_SUPPORTED" in str (e ) and original_dtype == torch .bfloat16 :
66- _BFLOAT16_SDPA_WORKS = False
67- print ("⚠️ [SeedVR2] GPU does not support bfloat16 SDPA, using float16 fallback. "
68- "Tiling artifacts or black frames may occur." )
69- out = _ORIGINAL_SDPA (query .half (), key .half (), value .half (), * args , ** kwargs )
70- return out .to (original_dtype )
71- raise
72-
73- # Apply SDPA patch at module load
74- torch .nn .functional .scaled_dot_product_attention = _safe_scaled_dot_product_attention
75-
76-
7744# Flash Attention & Triton Compatibility Layer
7845# 1. Flash Attention - speedup for attention operations
7946try :
@@ -236,6 +203,24 @@ def _check_conv3d_memory_bug():
236203 print (f"🔧 Conv3d workaround active: PyTorch { torch_ver } , cuDNN { cudnn_ver } (fixing VAE 3x memory bug)" )
237204
238205
206+ # Bfloat16 CUBLAS support
207+ def _probe_bfloat16_support () -> bool :
208+ if not torch .cuda .is_available ():
209+ return True
210+ try :
211+ a = torch .randn (8 , 8 , dtype = torch .bfloat16 , device = 'cuda:0' )
212+ _ = torch .matmul (a , a )
213+ del a
214+ return True
215+ except RuntimeError as e :
216+ if "CUBLAS_STATUS_NOT_SUPPORTED" in str (e ):
217+ return False
218+ raise
219+
220+ BFLOAT16_SUPPORTED = _probe_bfloat16_support ()
221+ COMPUTE_DTYPE = torch .bfloat16 if BFLOAT16_SUPPORTED else torch .float16
222+
223+
239224def call_rope_with_stability (method , * args , ** kwargs ):
240225 """
241226 Call RoPE method with stability fixes:
0 commit comments