Skip to content

Commit 0f65e74

Browse files
authored
Improve rocm7 behavior (#58)
* Added support for handling differences in the AWS OFI plugin when used on Slingshot systems. * Updated the version number. * Addressed reviewer feedback. * Added a note about forcing libfabric.
1 parent 4c97b24 commit 0f65e74

File tree

4 files changed

+42
-9
lines changed

4 files changed

+42
-9
lines changed

.github/workflows/release.sh

100644100755
File mode changed.

hpc_launcher/systems/lc/el_capitan_family.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from hpc_launcher.schedulers.flux import FluxScheduler
1616
from hpc_launcher.systems.system import System, SystemParams
1717
import os
18+
import re
1819

1920
import logging
2021

@@ -105,6 +106,17 @@ def environment_variables(self) -> list[tuple[str, str]]:
105106
if optimize_comm_protocol.upper() == "RCCL" or optimize_comm_protocol.upper() == "*CCL":
106107
optimize_rccl_protocol = True
107108

109+
aws_ofi_plugin = None
110+
different_ofi_plugin = os.getenv("LBANN_USE_THIS_OFI_PLUGIN")
111+
if different_ofi_plugin is not None:
112+
if os.path.isdir(different_ofi_plugin):
113+
env_list.append(
114+
("LD_LIBRARY_PATH", different_ofi_plugin + ":${LD_LIBRARY_PATH}")
115+
)
116+
aws_ofi_plugin = different_ofi_plugin
117+
else:
118+
logger.warn(f"WARNING: invalid path provided in LBANN_USE_THIS_OFI_PLUGIN: {different_ofi_plugin}. Ensure one is loaded or performance will be degraded.")
119+
108120
if os.getenv("ROCM_PATH") is not None:
109121
rocm_path = os.getenv("ROCM_PATH")
110122
env_list.append(
@@ -114,8 +126,9 @@ def environment_variables(self) -> list[tuple[str, str]]:
114126
+ ":${LD_LIBRARY_PATH}",
115127
)
116128
)
117-
if optimize_rccl_protocol:
118-
rocm_ver = os.path.basename(rocm_path)
129+
rocm_ver = os.path.basename(rocm_path)
130+
131+
if optimize_rccl_protocol and not aws_ofi_plugin:
119132
# Check for and include the AWS_OFI_PLUGIN if it exists
120133
sys_type = os.getenv("SYS_TYPE")
121134
aws_ofi_plugin = f'/collab/usr/global/tools/rccl/{sys_type}/{rocm_ver}/install/lib'
@@ -131,11 +144,23 @@ def environment_variables(self) -> list[tuple[str, str]]:
131144
else:
132145
logger.warn(f"WARNING: using RCCL communication protocol and no default AWS_OFI_RCCL plugin was detected. Checked {aws_ofi_plugin}. Ensure one is loaded or performance will be degraded.")
133146

134-
different_ofi_plugin = os.getenv("LBANN_USE_THIS_OFI_PLUGIN")
135-
if different_ofi_plugin is not None:
136-
env_list.append(
137-
("LD_LIBRARY_PATH", different_ofi_plugin + ":${LD_LIBRARY_PATH}")
138-
)
147+
match = re.match(r'rocm-(\d+)\.(\d+).(\d+)', rocm_ver)
148+
if match:
149+
rocm_major = int(match.group(1))
150+
rocm_minor = int(match.group(2))
151+
# rocm_patch = int(match.group(3))
152+
153+
# Unless overridden by an external env variable, set NCCL_NET to ensure that the libfabric interface is used (possible values, e.g.: libfabric, IB, Socket)
154+
msg = "By default HPC-launcher will force slingshot systems to use the libfabric NCCL/RCCL plugin or fail. This behavior can be overridden by setting NCCL_NET=Socket in the calling environment."
155+
if rocm_major >= 7 and rocm_minor >= 1:
156+
# Add AWS_OFI_NCCL for ROCm 7.1 - Ensure that it picks up the correct library object
157+
if not os.getenv("NCCL_NET_PLUGIN"):
158+
env_list.append(("NCCL_NET_PLUGIN", "librccl-net.so"))
159+
if not os.getenv("NCCL_NET"):
160+
env_list.append(("NCCL_NET", "libfabric", msg))
161+
else:
162+
if not os.getenv("NCCL_NET"):
163+
env_list.append(("NCCL_NET", '\"AWS Libfabric\"', msg))
139164

140165
if optimize_rccl_protocol:
141166
# Performance tuning for HPE Slingshot Cassini NIC (Audited on 3/31/25) - Only use with RCCL
@@ -167,6 +192,10 @@ def environment_variables(self) -> list[tuple[str, str]]:
167192
# Improve the performance of large scale RCCL initialization - should only be used on wire-up
168193
env_list.append(("NCCL_SOCKET_IFNAME", "hsi0"))
169194

195+
# Ensure that PyTorch respects the channels-last (NHWC) memory format for MIOpen (Audited on 1/13/2026)
196+
env_list.append(("PYTORCH_MIOPEN_SUGGEST_NHWC", "1"))
197+
env_list.append(("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM", "1"))
198+
170199
for i in self._aux_env_list:
171200
env_list.append(i)
172201

hpc_launcher/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.3"
1+
__version__ = "1.0.4"

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ def get_rocm_version():
3232
if rocm_version:
3333
# Constrain ROCm-dependent packages
3434
major, minor, patch = rocm_version.split('.')
35-
extras.append(f"amdsmi=={major}.{minor}.{patch}")
35+
# Releases of AMDSMI in PyPI are lagging github releases
36+
if int(major) >= 7:
37+
extras.append(f"amdsmi>={major},<={major}.{minor}.{patch}")
38+
else:
39+
extras.append(f"amdsmi=={major}.{minor}.{patch}")
3640
else:
3741
# Fallback or raise error
3842
raise RuntimeError("ROCm installation not found!")

0 commit comments

Comments
 (0)