diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py index b1a6c7194c..61e5665ae9 100644 --- a/tests/utils/plugin_test.py +++ b/tests/utils/plugin_test.py @@ -5,6 +5,7 @@ from typing import Type import ray +from parameterized import parameterized from tests.tools import TensorBoardParser, get_checkpoint_path, get_template_config from trinity.common.config import Config @@ -46,10 +47,17 @@ def run(self, workflow_cls=Type[Workflow]): class TestPluginLoader(unittest.TestCase): - def test_load_plugins_local(self): - my_workflow_cls = WORKFLOWS.get("my_workflow") - self.assertIsNone(my_workflow_cls) - os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + PLUGIN_DIR_PARAMS = [ + (str(Path(__file__).resolve().parent / "plugins"),), + (os.path.join("tests", "utils", "plugins"),), + ] + + @parameterized.expand(PLUGIN_DIR_PARAMS) + def test_load_plugins_local(self, plugin_dir): + if os.path.isabs(plugin_dir): + my_workflow_cls = WORKFLOWS.get("my_workflow") + self.assertIsNone(my_workflow_cls) + os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir try: load_plugins() except KeyError: @@ -63,8 +71,9 @@ def test_load_plugins_local(self): self.assertEqual(res[0], "Hello world") self.assertEqual(res[1], "Hi") - def test_load_plugins_remote(self): - os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + @parameterized.expand(PLUGIN_DIR_PARAMS) + def test_load_plugins_remote(self, plugin_dir): + os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir try: load_plugins() except KeyError: @@ -73,9 +82,7 @@ def test_load_plugins_remote(self): config = self.config ray.init( ignore_reinit_error=True, - runtime_env={ - "env_vars": {PLUGIN_DIRS_ENV_VAR: str(Path(__file__).resolve().parent / "plugins")} - }, + runtime_env={"env_vars": {PLUGIN_DIRS_ENV_VAR: plugin_dir}}, ) my_workflow_cls = WORKFLOWS.get("my_workflow") # disable plugin and use custom class from registry @@ -95,9 +102,10 @@ def test_load_plugins_remote(self): rollout_cnt = parser.metric_values("rollout") self.assertEqual(rollout_cnt, [2]) - def test_passing_custom_class(self): + @parameterized.expand(PLUGIN_DIR_PARAMS) + def test_passing_custom_class(self, plugin_dir): # disable plugin and pass custom class directly - os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir try: load_plugins() except KeyError: diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index c3d956f2b1..47e5b42bbc 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -45,7 +45,7 @@ def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None: logger.info(f"Loading plugin modules from [{file}]...") # load modules from file try: - load_from_file(os.path.join(plugin_dir, file)) + load_from_file(str(file)) except Exception as e: logger.warning(f"Failed to load plugin module from [{file}]: {e}")