File tree: 1 file changed, +18 −1 lines changed
lines changed Original file line number Diff line number Diff line change @@ -289,6 +289,21 @@ def is_amd():
289289 return True
290290 return False
291291
def amd_min_version(device=None, min_rdna_version=0):
    """Return True if the current GPU is an AMD card of at least the given RDNA generation.

    Args:
        device: torch device (or index) to query; None queries the current device.
        min_rdna_version: minimum RDNA generation required (e.g. 3 or 4).

    Returns:
        True only on AMD hardware whose gcnArchName parses to an RDNA
        generation >= min_rdna_version; False otherwise (including on
        non-AMD hardware or unrecognized arch strings).
    """
    if not is_amd():
        return False

    arch = torch.cuda.get_device_properties(device).gcnArchName
    # Heuristic mapping from arch names like 'gfx1100' (RDNA3) or 'gfx1201'
    # (RDNA4): digit at index 4 + 2 gives the RDNA generation.
    # NOTE(review): approximate — all gfx10xx map to 2 even though gfx101x
    # parts are RDNA1; acceptable here since callers ask for >= 3/4.
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            cmp_rdna_version = int(arch[4]) + 2
        # Was a bare `except:` (would swallow SystemExit/KeyboardInterrupt);
        # int() on a single character can only raise ValueError here.
        except ValueError:
            cmp_rdna_version = 0
        if cmp_rdna_version >= min_rdna_version:
            return True

    return False
306+
# Module-level tunable: defaults to 0.4 and is overridden to 0.0 on NVIDIA
# hardware. NOTE(review): presumably the minimum fraction of memory reserved
# for model weights during load decisions — confirm against the code that
# consumes this constant (not visible in this hunk). Further overrides may
# follow below this hunk boundary.
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
    MIN_WEIGHT_MEMORY_RATIO = 0.0
@@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
905920
906921 # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
907922 # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
908- if d == torch .bfloat16 and (not is_amd ()) and should_use_bf16 (device ):
923+ # also a problem on RDNA4 except fp32 is also slow there.
924+ # This is due to large bf16 convolutions being extremely slow.
925+ if d == torch .bfloat16 and ((not is_amd ()) or amd_min_version (device , min_rdna_version = 4 )) and should_use_bf16 (device ):
909926 return d
910927
911928 return torch .float32
You can’t perform that action at this time.
0 commit comments