Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions python/triton_dist/amd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import json
import warnings
import re
import os
import sys
from threading import Lock
from hip import hip

def _get_gpu_uuid_by_physical_device_id(device_id: int):
    """Return the UUID of the GPU at physical index ``device_id``.

    For HIP major versions below 7 the value reported by
    ``amdsmi.amdsmi_get_gpu_device_uuid`` is returned unchanged.

    For HIP major version 7 and later, a change in how UUIDs are generated
    for CPX mode means amdsmi no longer reports any UUID value that matches
    HIP/PyTorch. HIP takes the value from sysfs, so the same value is read
    here: the KFD node id is obtained from amdsmi and the ``unique_id`` entry
    of that node's sysfs ``properties`` file is returned as a hex string.

    Args:
        device_id: Index into the list of amdsmi processor handles.

    Returns:
        The amdsmi UUID (HIP < 7) or ``hex(unique_id)`` read from sysfs
        (HIP >= 7).

    Raises:
        RuntimeError: If the KFD ``properties`` file has no ``unique_id``
            entry. Raising (rather than implicitly returning ``None``) lets
            callers detect the failure and fall back to another source.
    """
    _ensure_amdsmi_initialized()
    devices = amdsmi.amdsmi_get_processor_handles()
    handle = devices[device_id]

    major_version = int(torch.version.hip.split('.')[0])
    if major_version < 7:
        return amdsmi.amdsmi_get_gpu_device_uuid(handle)

    kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle)
    node_id = kfd_info["node_id"]
    kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties")
    key = "unique_id"
    with open(kfd_path, "r") as fd:
        for line in fd:
            # Each property line is "<name> <value>"; compare the name exactly
            # so a longer key sharing the same prefix cannot be mistaken for it.
            fields = line.split()
            if len(fields) >= 2 and fields[0] == key:
                return hex(int(fields[1]))

    raise RuntimeError(f"no {key!r} entry found in {kfd_path}")


def torch_uuid_to_unique_id(torch_uuid: str) -> str:
Expand All @@ -232,7 +250,8 @@ def get_uuid_by_physical_device_id(device_id: int | None = None):
try:
if has_amdsmi():
return _get_gpu_uuid_by_physical_device_id(device_id)
except Exception:
except Exception as e:
print(e, file=sys.stderr)
warnings.warn("get_uuid_by_physical_device_id failed with amdsmi, try using rocm-smi")

return _get_physical_gpu_uuid_rocm(device_id)
Expand Down
6 changes: 3 additions & 3 deletions tutorials/03a-inter-node-allgather.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from triton_dist.language.extra.language_extra import __syncthreads, tid
from triton_dist.language.extra import libshmem_device
from triton_dist.profiler_utils import perf_func
from triton_dist.utils import finalize_distributed, initialize_distributed, rocshmem_barrier_all_on_stream, NVSHMEM_SIGNAL_DTYPE
from triton_dist.utils import finalize_distributed, initialize_distributed, rocshmem_barrier_all_on_stream, NVSHMEM_SIGNAL_DTYPE, sleep_async


@dataclass
Expand Down Expand Up @@ -139,7 +139,7 @@ def _run_all_gather_nccl():
print(f"✅ RANK[{RANK}] check passed")

# perf all-gather by NCCL
#sleep_async(1000) # in case CPU bound # Broken in rocm 7+
sleep_async(1000) # in case CPU bound # NOTE: was previously commented out as broken in rocm 7+; re-enabled
_, duration_per_iter_ms = perf_func(
_run_all_gather_nccl,
warmup_iters=5,
Expand All @@ -153,7 +153,7 @@ def _run_all_gather_nccl():

# perf all-gather by triton-distributed
rocshmem_barrier_all_on_stream(torch.cuda.current_stream())
#sleep_async(1000) # in case CPU bound
sleep_async(1000) # in case CPU bound
_, duration_per_iter_ms = perf_func(
_run_all_gather_triton,
warmup_iters=5,
Expand Down