
Commit 27a0fcc

Enable bf16 VAE on RDNA4. (#9746)
1 parent ea6cdd2


comfy/model_management.py

Lines changed: 18 additions & 1 deletion
@@ -289,6 +289,21 @@ def is_amd():
         return True
     return False
 
+def amd_min_version(device=None, min_rdna_version=0):
+    if not is_amd():
+        return False
+
+    arch = torch.cuda.get_device_properties(device).gcnArchName
+    if arch.startswith('gfx') and len(arch) == 7:
+        try:
+            cmp_rdna_version = int(arch[4]) + 2
+        except:
+            cmp_rdna_version = 0
+        if cmp_rdna_version >= min_rdna_version:
+            return True
+
+    return False
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
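
For reference, the arch check above turns a ROCm gcnArchName string like 'gfx1100' into an RDNA generation number: for the gfx1xxx family, the second digit is the generation minus 2. A minimal standalone sketch of that mapping (rdna_version_from_arch is a hypothetical helper for illustration, not part of the commit; the real code reads torch.cuda.get_device_properties(device).gcnArchName):

# Hypothetical sketch of the gcnArchName -> RDNA generation mapping.
# Assumes arch strings of the exact form 'gfx' + 4 digits, e.g. 'gfx1100'.
def rdna_version_from_arch(arch: str) -> int:
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            # Second digit of the gfx1xxx family is the RDNA generation minus 2.
            return int(arch[4]) + 2
        except ValueError:
            return 0
    return 0

assert rdna_version_from_arch('gfx1030') == 2  # RDNA2 (e.g. RX 6000 series)
assert rdna_version_from_arch('gfx1100') == 3  # RDNA3 (e.g. RX 7000 series)
assert rdna_version_from_arch('gfx1201') == 4  # RDNA4 (e.g. RX 9000 series)
assert rdna_version_from_arch('gfx906') == 0   # pre-RDNA, fails the length check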
@@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
 
         # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
         # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        # also a problem on RDNA4 except fp32 is also slow there.
+        # This is due to large bf16 convolutions being extremely slow.
+        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
             return d
 
     return torch.float32
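
The net effect of the new condition: bf16 stays enabled for non-AMD GPUs, is now also allowed on AMD RDNA4 and newer, and older AMD generations keep falling back to fp32. A minimal sketch of that decision with the device checks stubbed out (pick_vae_dtype and its parameters are hypothetical stand-ins; the real code calls is_amd(), amd_min_version(), and should_use_bf16()):

# Hypothetical sketch of the dtype decision after this commit, with the
# hardware queries replaced by plain parameters so it runs standalone.
import torch

def pick_vae_dtype(is_amd_gpu: bool, rdna_version: int, bf16_supported: bool) -> torch.dtype:
    d = torch.bfloat16
    # bf16 is allowed on non-AMD GPUs, or on AMD only from RDNA4 up.
    amd_ok = (not is_amd_gpu) or rdna_version >= 4
    if d == torch.bfloat16 and amd_ok and bf16_supported:
        return d
    return torch.float32

assert pick_vae_dtype(False, 0, True) == torch.bfloat16  # non-AMD with bf16 support
assert pick_vae_dtype(True, 3, True) == torch.float32    # RDNA3: keep fp32
assert pick_vae_dtype(True, 4, True) == torch.bfloat16   # RDNA4: bf16 now enabled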
