Skip to content

Commit 598fe95

Browse files
committed
feat(launcher): update _plugins.py
1 parent 2d1e5ea commit 598fe95

File tree

1 file changed

+42
-15
lines changed

1 file changed

+42
-15
lines changed

src/ansys/tools/common/launcher/_plugins.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,86 @@
2424

2525
import importlib.metadata
2626
from importlib.metadata import entry_points
27-
from typing import Any
27+
from functools import lru_cache
28+
from typing import Any, Type
2829

2930
from .interface import FALLBACK_LAUNCH_MODE_NAME, DataclassProtocol, LauncherProtocol
3031

3132
LAUNCHER_ENTRY_POINT = "ansys.tools.local_product_launcher.launcher"
3233

3334

3435
def get_launcher(*, product_name: str, launch_mode: str) -> type[LauncherProtocol[DataclassProtocol]]:
35-
"""Get the launcher plugin for a given product and launch mode."""
36+
"""Get the launcher plugin class for a given product and launch mode."""
3637
ep_name = f"{product_name}.{launch_mode}"
3738
for entrypoint in _get_entry_points():
3839
if entrypoint.name == ep_name:
3940
return entrypoint.load() # type: ignore
40-
else:
41-
raise KeyError(f"No plugin found for '{ep_name}'.")
41+
raise KeyError(f"No plugin found for '{ep_name}'.")
4242

4343

4444
def get_config_model(*, product_name: str, launch_mode: str) -> type[DataclassProtocol]:
45-
"""Get the configuration model for a given product and launch mode."""
45+
"""Get the configuration model class for a given product and launch mode."""
4646
return get_launcher(product_name=product_name, launch_mode=launch_mode).CONFIG_MODEL
4747

4848

49-
def get_all_plugins(hide_fallback: bool = True) -> dict[str, dict[str, LauncherProtocol[Any]]]:
50-
"""Get mapping {"<product_name>": {"<launch_mode>": Launcher}} containing all plugins."""
51-
res: dict[str, dict[str, LauncherProtocol[Any]]] = dict()
49+
def get_all_plugins(hide_fallback: bool = True) -> dict[str, dict[str, type[LauncherProtocol[Any]]]]:
50+
"""Get mapping {"<product_name>": {"<launch_mode>": LauncherClass}} containing all launcher plugins.
51+
52+
Parameters
53+
----------
54+
hide_fallback : bool, default=True
55+
If True, skip launch modes marked as fallback.
56+
57+
Returns
58+
-------
59+
dict[str, dict[str, type[LauncherProtocol[Any]]]]
60+
Mapping of product names to launch mode to launcher classes.
61+
"""
62+
res: dict[str, dict[str, type[LauncherProtocol[Any]]]] = dict()
5263
for entry_point in _get_entry_points():
53-
product_name, launch_mode = entry_point.name.split(".")
64+
try:
65+
product_name, launch_mode = entry_point.name.split(".")
66+
except ValueError:
67+
# skip malformed entry point names
68+
continue
69+
5470
if hide_fallback and launch_mode == FALLBACK_LAUNCH_MODE_NAME:
5571
continue
72+
73+
try:
74+
launcher_class = entry_point.load() # type: ignore
75+
except Exception:
76+
# skip broken plugins
77+
continue
78+
5679
res.setdefault(product_name, dict())
57-
res[product_name][launch_mode] = entry_point.load()
80+
res[product_name][launch_mode] = launcher_class
81+
5882
return res
5983

6084

6185
def has_fallback(product_name: str) -> bool:
62-
"""Return ``True`` if the given product has a fallback launcher."""
86+
"""Return True if the given product has a fallback launcher."""
6387
for entry_point in _get_entry_points():
64-
ep_product_name, ep_launch_mode = entry_point.name.split(".")
88+
try:
89+
ep_product_name, ep_launch_mode = entry_point.name.split(".")
90+
except ValueError:
91+
continue
6592
if product_name == ep_product_name and ep_launch_mode == FALLBACK_LAUNCH_MODE_NAME:
6693
return True
6794
return False
6895

6996

7097
def get_fallback_launcher(product_name: str) -> type[LauncherProtocol[DataclassProtocol]]:
71-
"""Get the fallback launcher for a given product."""
98+
"""Get the fallback launcher plugin class for a given product."""
7299
ep_name = f"{product_name}.{FALLBACK_LAUNCH_MODE_NAME}"
73100
for entrypoint in _get_entry_points():
74101
if entrypoint.name == ep_name:
75102
return entrypoint.load() # type: ignore
76-
else:
77-
raise KeyError(f"No fallback plugin found for '{product_name}'.")
103+
raise KeyError(f"No plugin found for '{ep_name}'.")
78104

79105

106+
@lru_cache
80107
def _get_entry_points() -> tuple[importlib.metadata.EntryPoint, ...]:
81108
"""Get all Local Product Launcher plugin entrypoints for launchers."""
82109
try:

0 commit comments

Comments
 (0)