Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions src/fastcs/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from collections.abc import Callable, Coroutine, Sequence
from functools import partial
from pathlib import Path
from typing import Annotated, Any, Optional, TypeAlias, get_type_hints
from typing import Annotated, Any, Optional, get_type_hints

import typer
from IPython.terminal.embed import InteractiveShellEmbed
from pydantic import BaseModel, create_model
from pydantic import BaseModel, ValidationError, create_model
from ruamel.yaml import YAML

from fastcs import __version__
Expand All @@ -26,11 +26,6 @@
)
from fastcs.logging import logger as _fastcs_logger
from fastcs.tracer import Tracer
from fastcs.transport.epics.ca.transport import EpicsCATransport
from fastcs.transport.epics.pva.transport import EpicsPVATransport
from fastcs.transport.graphql.transport import GraphQLTransport
from fastcs.transport.rest.transport import RestTransport
from fastcs.transport.tango.transport import TangoTransport

from .attributes import ONCE, AttrR, AttrW
from .controller import BaseController, Controller
Expand All @@ -41,15 +36,6 @@
from .transport import Transport
from .util import validate_hinted_attributes

# Define a type alias for transport options
TransportList: TypeAlias = list[
EpicsPVATransport
| EpicsCATransport
| TangoTransport
| RestTransport
| GraphQLTransport
]

tracer = Tracer(name=__name__)
logger = _fastcs_logger.bind(logger_name=__name__)

Expand Down Expand Up @@ -440,8 +426,19 @@ def run(

yaml = YAML(typ="safe")
options_yaml = yaml.load(config)
# To do: Handle a k8s "values.yaml" file
instance_options = fastcs_options.model_validate(options_yaml)

try:
instance_options = fastcs_options.model_validate(options_yaml)
except ValidationError as e:
if any("transport" in error["loc"] for error in json.loads(e.json())):
raise LaunchError(
"Failed to validate transports. "
"Are the correct fastcs extras installed? "
f"Available transports:\n{Transport.subclasses}",
) from e

raise LaunchError("Failed to validate config") from e

if hasattr(instance_options, "controller"):
controller = controller_class(instance_options.controller)
else:
Expand All @@ -466,7 +463,7 @@ def _extract_options_model(controller_class: type[Controller]) -> type[BaseModel
if len(args) == 1:
fastcs_options = create_model(
f"{controller_class.__name__}",
transport=(TransportList, ...),
transport=(list[Transport.union()], ...),
__config__={"extra": "forbid"},
)
elif len(args) == 2:
Expand All @@ -483,7 +480,7 @@ def _extract_options_model(controller_class: type[Controller]) -> type[BaseModel
fastcs_options = create_model(
f"{controller_class.__name__}",
controller=(options_type, ...),
transport=(TransportList, ...),
transport=(list[Transport.union()], ...),
__config__={"extra": "forbid"},
)
else:
Expand Down
38 changes: 29 additions & 9 deletions src/fastcs/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
from .epics.ca.transport import EpicsCATransport as EpicsCATransport
from .epics.options import EpicsDocsOptions as EpicsDocsOptions
from .epics.options import EpicsGUIOptions as EpicsGUIOptions
from .epics.options import EpicsIOCOptions as EpicsIOCOptions
from .epics.pva.transport import EpicsPVATransport as EpicsPVATransport
from .graphql.transport import GraphQLTransport as GraphQLTransport
from .rest.transport import RestTransport as RestTransport
from .tango.options import TangoDSROptions as TangoDSROptions
from .tango.transport import TangoTransport as TangoTransport
from .transport import Transport as Transport

try:
from .epics.ca.transport import EpicsCATransport as EpicsCATransport
from .epics.options import EpicsDocsOptions as EpicsDocsOptions
from .epics.options import EpicsGUIOptions as EpicsGUIOptions
from .epics.options import EpicsIOCOptions as EpicsIOCOptions
except ImportError:
pass

try:
from .epics.pva.transport import EpicsPVATransport as EpicsPVATransport
except ImportError:
pass

try:
from .graphql.transport import GraphQLTransport as GraphQLTransport
except ImportError:
pass

try:
from .rest.transport import RestTransport as RestTransport
except ImportError:
pass

try:
from .tango.options import TangoDSROptions as TangoDSROptions
from .tango.transport import TangoTransport as TangoTransport
except ImportError:
pass
14 changes: 13 additions & 1 deletion src/fastcs/transport/transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any
from typing import Any, ClassVar, Union

from fastcs.controller_api import ControllerAPI

Expand All @@ -11,6 +11,18 @@ class Transport:
"""A base class for transport's implementation
so it can be used in FastCS."""

subclasses: ClassVar[list[type["Transport"]]] = []

def __init_subclass__(cls):
cls.subclasses.append(cls)

@classmethod
def union(cls):
if not cls.subclasses:
raise RuntimeError("No Transports found")

return Union[tuple(cls.subclasses)] # noqa: UP007

@abstractmethod
async def serve(self) -> None:
pass
Expand Down
8 changes: 4 additions & 4 deletions tests/data/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,19 @@
"items": {
"anyOf": [
{
"$ref": "#/$defs/EpicsPVATransport"
"$ref": "#/$defs/EpicsCATransport"
},
{
"$ref": "#/$defs/EpicsCATransport"
"$ref": "#/$defs/EpicsPVATransport"
},
{
"$ref": "#/$defs/TangoTransport"
"$ref": "#/$defs/GraphQLTransport"
},
{
"$ref": "#/$defs/RestTransport"
},
{
"$ref": "#/$defs/GraphQLTransport"
"$ref": "#/$defs/TangoTransport"
}
]
},
Expand Down
6 changes: 3 additions & 3 deletions tests/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from fastcs.exceptions import FastCSError, LaunchError
from fastcs.launch import (
FastCS,
TransportList,
_launch,
build_controller_api,
get_controller_schema,
launch,
)
from fastcs.transport.transport import Transport
from fastcs.wrappers import command, scan


Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, arg: SomeConfig, too_many):
def test_single_arg_schema():
target_model = create_model(
"SingleArg",
transport=(TransportList, ...),
transport=(list[Transport.union()], ...),
__config__={"extra": "forbid"},
)
target_dict = target_model.model_json_schema()
Expand All @@ -78,7 +78,7 @@ def test_is_hinted_schema(data):
target_model = create_model(
"IsHinted",
controller=(SomeConfig, ...),
transport=(TransportList, ...),
transport=(list[Transport.union()], ...),
__config__={"extra": "forbid"},
)
target_dict = target_model.model_json_schema()
Expand Down
Loading