1 file changed: 18 additions, 1 deletion.
```diff
@@ -289,6 +289,21 @@ def is_amd():
             return True
     return False
 
+def amd_min_version(device=None, min_rdna_version=0):
+    if not is_amd():
+        return False
+
+    arch = torch.cuda.get_device_properties(device).gcnArchName
+    if arch.startswith('gfx') and len(arch) == 7:
+        try:
+            cmp_rdna_version = int(arch[4]) + 2
+        except:
+            cmp_rdna_version = 0
+        if cmp_rdna_version >= min_rdna_version:
+            return True
+
+    return False
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
```
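The new `amd_min_version` helper infers the RDNA generation from PyTorch's `gcnArchName` string: for the 7-character `gfxNNNN` pattern, the digit at index 4 plus 2 gives the generation (`gfx1100` is RDNA3; `gfx1200`/`gfx1201` are RDNA4), and anything unparseable compares as 0. A minimal standalone sketch of that heuristic (`rdna_version_from_arch` is a hypothetical name, not part of the PR):

```python
# Illustrative sketch of amd_min_version's parsing heuristic (not from the PR).
def rdna_version_from_arch(arch):
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            # Second digit of the gfx number: '1' -> RDNA3, '2' -> RDNA4.
            return int(arch[4]) + 2
        except ValueError:
            return 0
    return 0

print(rdna_version_from_arch("gfx1100"))  # 3 (RDNA3)
print(rdna_version_from_arch("gfx1201"))  # 4 (RDNA4)
print(rdna_version_from_arch("gfx906"))   # 0 (older GCN; pattern not matched)
```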
```diff
@@ -905,7 +920,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
 
         # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
         # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        # also a problem on RDNA4 except fp32 is also slow there.
+        # This is due to large bf16 convolutions being extremely slow.
+        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
             return d
 
     return torch.float32
```
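The effect of the widened condition, assuming `should_use_bf16(device)` is true: non-AMD devices keep using bfloat16 for the VAE as before, AMD RDNA4 and newer now opt in, and older AMD generations still fall through to float32. A hypothetical sketch of just that predicate:

```python
# Illustrative sketch of the gating implied by the new check (not from the
# PR), with should_use_bf16(device) assumed True.
def allows_bf16_vae(is_amd_gpu: bool, rdna_version: int) -> bool:
    # Mirrors: (not is_amd()) or amd_min_version(device, min_rdna_version=4)
    return (not is_amd_gpu) or rdna_version >= 4

print(allows_bf16_vae(False, 0))  # True  -> non-AMD: bf16 as before
print(allows_bf16_vae(True, 3))   # False -> RDNA3: falls through to fp32
print(allows_bf16_vae(True, 4))   # True  -> RDNA4: bf16 now allowed
```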