|
24 | 24 |
|
25 | 25 | import importlib.metadata |
26 | 26 | from importlib.metadata import entry_points |
27 | | -from typing import Any |
| 27 | +from functools import lru_cache |
| 28 | +from typing import Any, Type |
28 | 29 |
|
29 | 30 | from .interface import FALLBACK_LAUNCH_MODE_NAME, DataclassProtocol, LauncherProtocol |
30 | 31 |
|
31 | 32 | LAUNCHER_ENTRY_POINT = "ansys.tools.local_product_launcher.launcher" |
32 | 33 |
|
33 | 34 |
|
34 | 35 | def get_launcher(*, product_name: str, launch_mode: str) -> type[LauncherProtocol[DataclassProtocol]]: |
35 | | - """Get the launcher plugin for a given product and launch mode.""" |
| 36 | + """Get the launcher plugin class for a given product and launch mode.""" |
36 | 37 | ep_name = f"{product_name}.{launch_mode}" |
37 | 38 | for entrypoint in _get_entry_points(): |
38 | 39 | if entrypoint.name == ep_name: |
39 | 40 | return entrypoint.load() # type: ignore |
40 | | - else: |
41 | | - raise KeyError(f"No plugin found for '{ep_name}'.") |
| 41 | + raise KeyError(f"No plugin found for '{ep_name}'.") |
42 | 42 |
|
43 | 43 |
|
44 | 44 | def get_config_model(*, product_name: str, launch_mode: str) -> type[DataclassProtocol]: |
45 | | - """Get the configuration model for a given product and launch mode.""" |
| 45 | + """Get the configuration model class for a given product and launch mode.""" |
46 | 46 | return get_launcher(product_name=product_name, launch_mode=launch_mode).CONFIG_MODEL |
47 | 47 |
|
48 | 48 |
|
49 | | -def get_all_plugins(hide_fallback: bool = True) -> dict[str, dict[str, LauncherProtocol[Any]]]: |
50 | | - """Get mapping {"<product_name>": {"<launch_mode>": Launcher}} containing all plugins.""" |
51 | | - res: dict[str, dict[str, LauncherProtocol[Any]]] = dict() |
| 49 | +def get_all_plugins(hide_fallback: bool = True) -> dict[str, dict[str, type[LauncherProtocol[Any]]]]: |
| 50 | + """Get mapping {"<product_name>": {"<launch_mode>": LauncherClass}} containing all launcher plugins. |
| 51 | +
|
| 52 | + Parameters |
| 53 | + ---------- |
| 54 | + hide_fallback : bool, default=True |
| 55 | + If True, skip launch modes marked as fallback. |
| 56 | +
|
| 57 | + Returns |
| 58 | + ------- |
| 59 | + dict[str, dict[str, type[LauncherProtocol[Any]]]] |
| 60 | + Mapping of product names to launch mode to launcher classes. |
| 61 | + """ |
| 62 | + res: dict[str, dict[str, type[LauncherProtocol[Any]]]] = dict() |
52 | 63 | for entry_point in _get_entry_points(): |
53 | | - product_name, launch_mode = entry_point.name.split(".") |
| 64 | + try: |
| 65 | + product_name, launch_mode = entry_point.name.split(".") |
| 66 | + except ValueError: |
| 67 | + # skip malformed entry point names |
| 68 | + continue |
| 69 | + |
54 | 70 | if hide_fallback and launch_mode == FALLBACK_LAUNCH_MODE_NAME: |
55 | 71 | continue |
| 72 | + |
| 73 | + try: |
| 74 | + launcher_class = entry_point.load() # type: ignore |
| 75 | + except Exception: |
| 76 | + # skip broken plugins |
| 77 | + continue |
| 78 | + |
56 | 79 | res.setdefault(product_name, dict()) |
57 | | - res[product_name][launch_mode] = entry_point.load() |
| 80 | + res[product_name][launch_mode] = launcher_class |
| 81 | + |
58 | 82 | return res |
59 | 83 |
|
60 | 84 |
|
61 | 85 | def has_fallback(product_name: str) -> bool: |
62 | | - """Return ``True`` if the given product has a fallback launcher.""" |
| 86 | + """Return True if the given product has a fallback launcher.""" |
63 | 87 | for entry_point in _get_entry_points(): |
64 | | - ep_product_name, ep_launch_mode = entry_point.name.split(".") |
| 88 | + try: |
| 89 | + ep_product_name, ep_launch_mode = entry_point.name.split(".") |
| 90 | + except ValueError: |
| 91 | + continue |
65 | 92 | if product_name == ep_product_name and ep_launch_mode == FALLBACK_LAUNCH_MODE_NAME: |
66 | 93 | return True |
67 | 94 | return False |
68 | 95 |
|
69 | 96 |
|
70 | 97 | def get_fallback_launcher(product_name: str) -> type[LauncherProtocol[DataclassProtocol]]: |
71 | | - """Get the fallback launcher for a given product.""" |
| 98 | + """Get the fallback launcher plugin class for a given product.""" |
72 | 99 | ep_name = f"{product_name}.{FALLBACK_LAUNCH_MODE_NAME}" |
73 | 100 | for entrypoint in _get_entry_points(): |
74 | 101 | if entrypoint.name == ep_name: |
75 | 102 | return entrypoint.load() # type: ignore |
76 | | - else: |
77 | | - raise KeyError(f"No fallback plugin found for '{product_name}'.") |
| 103 | + raise KeyError(f"No plugin found for '{ep_name}'.") |
78 | 104 |
|
79 | 105 |
|
| 106 | +@lru_cache |
80 | 107 | def _get_entry_points() -> tuple[importlib.metadata.EntryPoint, ...]: |
81 | 108 | """Get all Local Product Launcher plugin entrypoints for launchers.""" |
82 | 109 | try: |
|
0 commit comments