diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index 4e41f1b59..6866549d1 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -6,11 +6,11 @@ from collections.abc import Callable, Coroutine, Sequence from functools import partial from pathlib import Path -from typing import Annotated, Any, Optional, TypeAlias, get_type_hints +from typing import Annotated, Any, Optional, get_type_hints import typer from IPython.terminal.embed import InteractiveShellEmbed -from pydantic import BaseModel, create_model +from pydantic import BaseModel, ValidationError, create_model from ruamel.yaml import YAML from fastcs import __version__ @@ -26,11 +26,6 @@ ) from fastcs.logging import logger as _fastcs_logger from fastcs.tracer import Tracer -from fastcs.transport.epics.ca.transport import EpicsCATransport -from fastcs.transport.epics.pva.transport import EpicsPVATransport -from fastcs.transport.graphql.transport import GraphQLTransport -from fastcs.transport.rest.transport import RestTransport -from fastcs.transport.tango.transport import TangoTransport from .attributes import ONCE, AttrR, AttrW from .controller import BaseController, Controller @@ -41,15 +36,6 @@ from .transport import Transport from .util import validate_hinted_attributes -# Define a type alias for transport options -TransportList: TypeAlias = list[ - EpicsPVATransport - | EpicsCATransport - | TangoTransport - | RestTransport - | GraphQLTransport -] - tracer = Tracer(name=__name__) logger = _fastcs_logger.bind(logger_name=__name__) @@ -440,8 +426,19 @@ def run( yaml = YAML(typ="safe") options_yaml = yaml.load(config) - # To do: Handle a k8s "values.yaml" file - instance_options = fastcs_options.model_validate(options_yaml) + + try: + instance_options = fastcs_options.model_validate(options_yaml) + except ValidationError as e: + if any("transport" in error["loc"] for error in json.loads(e.json())): + raise LaunchError( + "Failed to validate transports. " + "Are the correct fastcs extras installed? " + f"Available transports:\n{Transport.subclasses}", + ) from e + + raise LaunchError("Failed to validate config") from e + if hasattr(instance_options, "controller"): controller = controller_class(instance_options.controller) else: @@ -466,7 +463,7 @@ def _extract_options_model(controller_class: type[Controller]) -> type[BaseModel if len(args) == 1: fastcs_options = create_model( f"{controller_class.__name__}", - transport=(TransportList, ...), + transport=(list[Transport.union()], ...), __config__={"extra": "forbid"}, ) elif len(args) == 2: @@ -483,7 +480,7 @@ def _extract_options_model(controller_class: type[Controller]) -> type[BaseModel fastcs_options = create_model( f"{controller_class.__name__}", controller=(options_type, ...), - transport=(TransportList, ...), + transport=(list[Transport.union()], ...), __config__={"extra": "forbid"}, ) else: diff --git a/src/fastcs/transport/__init__.py b/src/fastcs/transport/__init__.py index b79ba26b0..c5bc22b5d 100644 --- a/src/fastcs/transport/__init__.py +++ b/src/fastcs/transport/__init__.py @@ -1,10 +1,30 @@ -from .epics.ca.transport import EpicsCATransport as EpicsCATransport -from .epics.options import EpicsDocsOptions as EpicsDocsOptions -from .epics.options import EpicsGUIOptions as EpicsGUIOptions -from .epics.options import EpicsIOCOptions as EpicsIOCOptions -from .epics.pva.transport import EpicsPVATransport as EpicsPVATransport -from .graphql.transport import GraphQLTransport as GraphQLTransport -from .rest.transport import RestTransport as RestTransport -from .tango.options import TangoDSROptions as TangoDSROptions -from .tango.transport import TangoTransport as TangoTransport from .transport import Transport as Transport + +try: + from .epics.ca.transport import EpicsCATransport as EpicsCATransport + from .epics.options import EpicsDocsOptions as EpicsDocsOptions + from .epics.options import EpicsGUIOptions as EpicsGUIOptions + from .epics.options import EpicsIOCOptions as EpicsIOCOptions +except ImportError: + pass + +try: + from .epics.pva.transport import EpicsPVATransport as EpicsPVATransport +except ImportError: + pass + +try: + from .graphql.transport import GraphQLTransport as GraphQLTransport +except ImportError: + pass + +try: + from .rest.transport import RestTransport as RestTransport +except ImportError: + pass + +try: + from .tango.options import TangoDSROptions as TangoDSROptions + from .tango.transport import TangoTransport as TangoTransport +except ImportError: + pass diff --git a/src/fastcs/transport/transport.py b/src/fastcs/transport/transport.py index 6d1b306db..c1904c8a2 100644 --- a/src/fastcs/transport/transport.py +++ b/src/fastcs/transport/transport.py @@ -1,7 +1,7 @@ import asyncio from abc import abstractmethod from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar, Union from fastcs.controller_api import ControllerAPI @@ -11,6 +11,18 @@ class Transport: """A base class for transport's implementation so it can be used in FastCS.""" + subclasses: ClassVar[list[type["Transport"]]] = [] + + def __init_subclass__(cls): + cls.subclasses.append(cls) + + @classmethod + def union(cls): + if not cls.subclasses: + raise RuntimeError("No Transports found") + + return Union[tuple(cls.subclasses)] # noqa: UP007 + @abstractmethod async def serve(self) -> None: pass diff --git a/tests/data/schema.json b/tests/data/schema.json index 4f2d6f1ac..9af08266c 100644 --- a/tests/data/schema.json +++ b/tests/data/schema.json @@ -208,19 +208,19 @@ "items": { "anyOf": [ { - "$ref": "#/$defs/EpicsPVATransport" + "$ref": "#/$defs/EpicsCATransport" }, { - "$ref": "#/$defs/EpicsCATransport" + "$ref": "#/$defs/EpicsPVATransport" }, { - "$ref": "#/$defs/TangoTransport" + "$ref": "#/$defs/GraphQLTransport" }, { "$ref": "#/$defs/RestTransport" }, { - "$ref": "#/$defs/GraphQLTransport" + "$ref": "#/$defs/TangoTransport" } ] }, diff --git a/tests/test_launch.py b/tests/test_launch.py index 811e4ed40..1646f2115 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -19,12 +19,12 @@ from fastcs.exceptions import FastCSError, LaunchError from fastcs.launch import ( FastCS, - TransportList, _launch, build_controller_api, get_controller_schema, launch, ) +from fastcs.transport.transport import Transport from fastcs.wrappers import command, scan @@ -61,7 +61,7 @@ def __init__(self, arg: SomeConfig, too_many): def test_single_arg_schema(): target_model = create_model( "SingleArg", - transport=(TransportList, ...), + transport=(list[Transport.union()], ...), __config__={"extra": "forbid"}, ) target_dict = target_model.model_json_schema() @@ -78,7 +78,7 @@ def test_is_hinted_schema(data): target_model = create_model( "IsHinted", controller=(SomeConfig, ...), - transport=(TransportList, ...), + transport=(list[Transport.union()], ...), __config__={"extra": "forbid"}, ) target_dict = target_model.model_json_schema()