Skip to content

Commit a579137

Browse files
joyang-nv and terrykong authored
fix: can't find transformers_modules error for moonlight (#1124)
Signed-off-by: Jonas Yang <joyang@nvidia.com> Signed-off-by: Jonas Yang CN <joyang@nvidia.com> Signed-off-by: Terry Kong <terryk@nvidia.com> Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Co-authored-by: Terry Kong <terryk@nvidia.com>
1 parent 38f0543 commit a579137

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

nemo_rl/distributed/worker_groups.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
)
2828
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
2929
from nemo_rl.distributed.worker_group_utils import recursive_merge_options
30-
from nemo_rl.utils.venvs import create_local_venv_on_each_node
30+
from nemo_rl.utils.venvs import (
31+
create_local_venv_on_each_node,
32+
patch_transformers_module_dir,
33+
)
3134

3235

3336
@dataclass
@@ -529,6 +532,7 @@ def _create_workers_from_bundle_indices(
529532
}
530533
runtime_env["env_vars"]["VIRTUAL_ENV"] = py_executable
531534
runtime_env["env_vars"]["UV_PROJECT_ENVIRONMENT"] = py_executable
535+
patch_transformers_module_dir(runtime_env["env_vars"])
532536

533537
extra_options = {"runtime_env": runtime_env, "name": name}
534538

nemo_rl/utils/venvs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,18 @@ def create_local_venv_on_each_node(py_executable: str, venv_name: str):
187187
ray.util.remove_placement_group(pg)
188188
# Return mapping from node IP to venv python path
189189
return paths[0]
190+
191+
192+
# Need to set PYTHONPATH to include transformers downloaded modules.
# Assuming the cache directory is the same across venvs.
def patch_transformers_module_dir(env_vars: dict[str, str]) -> dict[str, str]:
    """Prepend the transformers dynamic-modules directory to ``PYTHONPATH``.

    The directory is derived as the ``modules`` sibling of
    ``TRANSFORMERS_CACHE`` (i.e. ``<TRANSFORMERS_CACHE>/../modules``) and is
    placed first in ``env_vars["PYTHONPATH"]`` so it takes precedence over any
    entries that were already there.

    Args:
        env_vars: Environment-variable mapping to patch. It is modified in
            place.

    Returns:
        The same ``env_vars`` mapping, for call chaining.

    Raises:
        RuntimeError: If ``TRANSFORMERS_CACHE`` is unset/empty, so no modules
            directory can be derived.
    """
    # Imported lazily so importing this module does not pull in transformers.
    from transformers.utils.hub import TRANSFORMERS_CACHE

    # Validate the input of the path computation; the joined path itself can
    # never be None, so checking it would be a no-op.
    if not TRANSFORMERS_CACHE:
        raise RuntimeError("TRANSFORMERS_CACHE should exist.")

    # Collapse the ".." segment so the entry is a clean, comparable path.
    module_dir = os.path.normpath(
        os.path.join(TRANSFORMERS_CACHE, "..", "modules")
    )

    # Drop empty entries (an empty PYTHONPATH component means "current
    # directory" to the interpreter) and any previous copy of module_dir so
    # repeated calls on the same mapping stay idempotent.
    existing_entries = [
        entry
        for entry in env_vars.get("PYTHONPATH", "").split(os.pathsep)
        if entry and entry != module_dir
    ]
    env_vars["PYTHONPATH"] = os.pathsep.join([module_dir, *existing_entries])

    return env_vars

0 commit comments

Comments
 (0)