Skip to content

Commit 596eb8c

Browse files
committed
fix sys path
1 parent 3b480d6 commit 596eb8c

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

tests/utils/registry_test.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
import unittest
22

3-
from trinity.common.workflows import WORKFLOWS, Workflow
3+
import ray
44

55

6-
class TestRegistry(unittest.TestCase):
7-
def test_dynamic_import(self):
6+
class ImportUtils:
7+
def run(self):
8+
from trinity.common.workflows import WORKFLOWS, Workflow
9+
810
workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
9-
self.assertTrue(issubclass(workflow_cls, Workflow))
11+
assert issubclass(workflow_cls, Workflow)
1012
workflow = workflow_cls(task=None, model=None)
1113
res = workflow.run()
12-
self.assertEqual(res[0], 0)
13-
self.assertEqual(res[1], "0")
14+
assert res[0] == 0
15+
assert res[1] == "0"
16+
17+
18+
class TestRegistry(unittest.TestCase):
19+
def setUp(self):
20+
ray.init(ignore_reinit_error=True)
21+
22+
def tearDown(self):
23+
ray.shutdown()
24+
25+
def test_dynamic_import(self):
26+
# test local import
27+
ImportUtils().run()
28+
# test remote import
29+
ray.get(ray.remote(ImportUtils).remote().run.remote())

trinity/cli/launcher.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,8 @@ def debug(
248248
load_plugins()
249249
config = load_config(config_path)
250250
config.check_and_update()
251+
sys.path.insert(0, os.getcwd())
251252
config.ray_namespace = DEBUG_NAMESPACE
252-
ray.init(
253-
namespace=config.ray_namespace,
254-
runtime_env={"env_vars": config.get_envs()},
255-
ignore_reinit_error=True,
256-
)
257253
from trinity.common.models import create_debug_inference_model
258254

259255
if module == "inference_model":

trinity/utils/registry.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import traceback
23
from typing import Any, Type
34

@@ -145,6 +146,11 @@ def _dynamic_import(self, module_path: str, class_name: str) -> Type:
145146
"""
146147
import importlib
147148

148-
module = importlib.import_module(module_path)
149+
try:
150+
module = importlib.import_module(module_path)
151+
except ImportError:
152+
self.logger.error(f"system path: {sys.path}")
153+
self.logger.error(f"Cannot import module {module_path}")
154+
raise ImportError(f"Cannot import module {module_path}")
149155
module_cls = getattr(module, class_name)
150156
return module_cls

0 commit comments

Comments
 (0)