@@ -22,6 +22,7 @@
 from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
+import importlib
 import platform
 import weakref
 import gc
@@ -289,6 +290,24 @@ def is_amd():
             return True
     return False

+def amd_min_version(device=None, min_rdna_version=0):
+    if not is_amd():
+        return False
+
+    if is_device_cpu(device):
+        return False
+
+    arch = torch.cuda.get_device_properties(device).gcnArchName
+    if arch.startswith('gfx') and len(arch) == 7:
+        try:
+            cmp_rdna_version = int(arch[4]) + 2
+        except:
+            cmp_rdna_version = 0
+        if cmp_rdna_version >= min_rdna_version:
+            return True
+
+    return False
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
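For illustration only (not part of the commit): a minimal standalone sketch of the mapping `amd_min_version()` relies on, where the second digit of a 7-character `gcnArchName` plus 2 gives the RDNA generation. The helper name `rdna_generation` and the sample arch strings are made up for this sketch.

```python
# Hypothetical helper mirroring amd_min_version()'s parsing: 7-character names
# like "gfx1100" or "gfx1201" encode the generation in their second digit.
def rdna_generation(arch: str) -> int:
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            return int(arch[4]) + 2  # "gfx1100" -> 3 (RDNA3), "gfx1201" -> 4 (RDNA4)
        except ValueError:
            return 0
    return 0  # 6-character CDNA names like "gfx942" never pass the length check

for arch in ["gfx1030", "gfx1100", "gfx1201", "gfx942"]:
    print(arch, "->", rdna_generation(arch))
# gfx1030 -> 2, gfx1100 -> 3, gfx1201 -> 4, gfx942 -> 0
```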
@@ -321,12 +340,13 @@ def is_amd():
         logging.info("AMD arch: {}".format(arch))
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
-                    ENABLE_PYTORCH_ATTENTION = True
-#            if torch_version_numeric >= (2, 8):
-#                if any((a in arch) for a in ["gfx1201"]):
-#                    ENABLE_PYTORCH_ATTENTION = True
+            if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+                if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                        ENABLE_PYTORCH_ATTENTION = True
+#            if torch_version_numeric >= (2, 8):
+#                if any((a in arch) for a in ["gfx1201"]):
+#                    ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
                 SUPPORT_FP8_OPS = True
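For illustration only (not part of the commit): the triton probe uses `importlib.util.find_spec`, which locates a module without importing it, so the check is cheap even when triton is installed. A small sketch of the same pattern follows; the log messages are invented for this example, and it imports `importlib.util` explicitly since `find_spec` lives in that submodule.

```python
import importlib.util
import logging

# find_spec() returns a ModuleSpec if 'triton' can be located on sys.path,
# or None if it is not installed; the module itself is never imported here.
if importlib.util.find_spec('triton') is not None:
    logging.info("triton found, PyTorch attention can be enabled on supported AMD arches")
else:
    logging.info("triton not found, keeping the default attention path")
```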
@@ -905,7 +925,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):

         # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
         # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        # also a problem on RDNA4 except fp32 is also slow there.
+        # This is due to large bf16 convolutions being extremely slow.
+        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
             return d

     return torch.float32
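For illustration only (not part of the commit): a condensed trace of how the changed condition behaves, with `is_amd()`, `amd_min_version()`, and `should_use_bf16()` replaced by plain booleans so the branch logic can be run standalone.

```python
import torch

def pick_vae_dtype(amd: bool, rdna_gen: int, bf16_ok: bool = True):
    d = torch.bfloat16
    rdna4_or_newer = amd and rdna_gen >= 4  # stand-in for amd_min_version(device, min_rdna_version=4)
    if d == torch.bfloat16 and ((not amd) or rdna4_or_newer) and bf16_ok:
        return d
    return torch.float32

print(pick_vae_dtype(amd=False, rdna_gen=0))  # torch.bfloat16: non-AMD keeps bf16 as before
print(pick_vae_dtype(amd=True, rdna_gen=3))   # torch.float32: RDNA3 still avoids slow bf16 convolutions
print(pick_vae_dtype(amd=True, rdna_gen=4))   # torch.bfloat16: RDNA4 is now allowed to use bf16
```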