Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ classifiers = [
description = "Control system agnostic framework for building Device support in Python that will work for both EPICS and Tango"
dependencies = [
"aioserial",
"fastapi[standard]",
"numpy",
"pydantic",
"pvi~=0.10.0",
Expand Down Expand Up @@ -43,6 +44,7 @@ dev = [
"types-mock",
"aioca",
"p4p",
"httpx",
]

[project.scripts]
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions src/fastcs/backends/rest/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastcs.backend import Backend
from fastcs.controller import Controller

from .rest import RestServer


class RestBackend(Backend):
def __init__(self, controller: Controller):
super().__init__(controller)

self._server = RestServer(self._mapping)

def _run(self):
self._server.run()
147 changes: 147 additions & 0 deletions src/fastcs/backends/rest/rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass
from types import MethodType
from typing import Any

import uvicorn
from fastapi import FastAPI
from pydantic import create_model

from fastcs.attributes import AttrR, AttrRW, AttrW, T
from fastcs.controller import BaseController
from fastcs.mapping import Mapping


@dataclass
class RestServerOptions:
host: str = "localhost"
port: int = 8080
log_level: str = "info"


class RestServer:
def __init__(self, mapping: Mapping):
self._mapping = mapping
self._app = self._create_app()

def _create_app(self):
app = FastAPI()
_add_dev_attributes(app, self._mapping)
_add_dev_commands(app, self._mapping)

return app

def run(self, options: RestServerOptions | None = None) -> None:
if options is None:
options = RestServerOptions()

uvicorn.run(
self._app,
host=options.host,
port=options.port,
log_level=options.log_level,
)


def _put_request_body(attribute: AttrW[T]):
return create_model(
f"Put{str(attribute.datatype.dtype)}Value",
**{"value": (attribute.datatype.dtype, ...)}, # type: ignore
)


def _wrap_attr_put(
attribute: AttrW[T],
) -> Callable[[T], Coroutine[Any, Any, None]]:
async def attr_set(request):
await attribute.process(request.value)

# Fast api uses type annotations for validation, schema, conversions
attr_set.__annotations__["request"] = _put_request_body(attribute)

return attr_set


def _get_response_body(attribute: AttrR[T]):
return create_model(
f"Get{str(attribute.datatype.dtype)}Value",
**{"value": (attribute.datatype.dtype, ...)}, # type: ignore
)


def _wrap_attr_get(
attribute: AttrR[T],
) -> Callable[[], Coroutine[Any, Any, Any]]:
async def attr_get() -> Any: # Must be any as response_model is set
value = attribute.get() # type: ignore
return {"value": value}

return attr_get


def _add_dev_attributes(app: FastAPI, mapping: Mapping) -> None:
for single_mapping in mapping.get_controller_mappings():
path = single_mapping.controller.path

for attr_name, attribute in single_mapping.attributes.items():
attr_name = attr_name.title().replace("_", "")
d_attr_name = f"{'/'.join(path)}/{attr_name}" if path else attr_name

match attribute:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
case AttrRW():
app.add_api_route(
f"/{d_attr_name}",
_wrap_attr_get(attribute),
methods=["GET"], # Idemponent and safe data retrieval,
status_code=200, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/GET
response_model=_get_response_body(attribute),
)
app.add_api_route(
f"/{d_attr_name}",
_wrap_attr_put(attribute),
methods=["PUT"], # Idempotent state change
status_code=204, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/PUT
)
case AttrR():
app.add_api_route(
f"/{d_attr_name}",
_wrap_attr_get(attribute),
methods=["GET"],
status_code=200,
response_model=_get_response_body(attribute),
)
case AttrW():
app.add_api_route(
f"/{d_attr_name}",
_wrap_attr_put(attribute),
methods=["PUT"],
status_code=204,
)


def _wrap_command(
method: Callable, controller: BaseController
) -> Callable[..., Awaitable[None]]:
async def command() -> None:
await getattr(controller, method.__name__)()

return command


def _add_dev_commands(app: FastAPI, mapping: Mapping) -> None:
for single_mapping in mapping.get_controller_mappings():
path = single_mapping.controller.path

for name, method in single_mapping.command_methods.items():
cmd_name = name.title().replace("_", "")
d_cmd_name = f"{'/'.join(path)}/{cmd_name}" if path else cmd_name
app.add_api_route(
f"/{d_cmd_name}",
_wrap_command(
method.fn,
single_mapping.controller,
),
methods=["PUT"],
status_code=204,
)
90 changes: 90 additions & 0 deletions tests/backends/rest/test_rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import copy
import re
from typing import Any

import pytest
from fastapi.testclient import TestClient

from fastcs.attributes import AttrR
from fastcs.backends.rest.backend import RestBackend
from fastcs.datatypes import Bool, Float, Int


def pascal_2_snake(input: list[str]) -> list[str]:
snake_list = copy.deepcopy(input)
snake_list[-1] = re.sub(r"(?<!^)(?=[A-Z])", "_", snake_list[-1]).lower()
return snake_list


class TestRestServer:
@pytest.fixture(scope="class", autouse=True)
def setup_class(self, assertable_controller):
self.controller = assertable_controller

@pytest.fixture(scope="class")
def client(self):
app = RestBackend(self.controller)._server._app
return TestClient(app)

@pytest.fixture(scope="class")
def client_read(self, client):
def _client_read(path: list[str], expected: Any):
route = "/" + "/".join(path)
with self.controller.assertPerformed(pascal_2_snake(path), "READ"):
response = client.get(route)
assert response.status_code == 200
assert response.json()["value"] == expected

return _client_read

@pytest.fixture(scope="class")
def client_write(self, client):
def _client_write(path: list[str], value: Any):
route = "/" + "/".join(path)
with self.controller.assertPerformed(pascal_2_snake(path), "WRITE"):
response = client.put(route, json={"value": value})
assert response.status_code == 204

return _client_write

@pytest.fixture(scope="class")
def client_exec(self, client):
def _client_exec(path: list[str]):
route = "/" + "/".join(path)
with self.controller.assertPerformed(pascal_2_snake(path), "EXECUTE"):
response = client.put(route)
assert response.status_code == 204

return _client_exec

def test_read_int(self, client_read):
client_read(["ReadInt"], AttrR(Int())._value)

def test_read_write_int(self, client_read, client_write):
client_read(["ReadWriteInt"], AttrR(Int())._value)
client_write(["ReadWriteInt"], AttrR(Int())._value)

def test_read_write_float(self, client_read, client_write):
client_read(["ReadWriteFloat"], AttrR(Float())._value)
client_write(["ReadWriteFloat"], AttrR(Float())._value)

def test_read_bool(self, client_read):
client_read(["ReadBool"], AttrR(Bool())._value)

def test_write_bool(self, client_write):
client_write(["WriteBool"], AttrR(Bool())._value)

# # We need to discuss enums
# def test_string_enum(self, client_read, client_write):

def test_big_enum(self, client_read):
client_read(["BigEnum"], AttrR(Int(), allowed_values=list(range(1, 18)))._value)

def test_go(self, client_exec):
client_exec(["Go"])

def test_read_child1(self, client_read):
client_read(["SubController01", "ReadInt"], AttrR(Int())._value)

def test_read_child2(self, client_read):
client_read(["SubController02", "ReadInt"], AttrR(Int())._value)
67 changes: 65 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import copy
import os
import random
import string
import subprocess
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from typing import Any, Literal

import pytest
from aioca import purge_channel_caches
from pytest_mock import MockerFixture

from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater
from fastcs.controller import Controller
from fastcs.controller import Controller, SubController
from fastcs.datatypes import Bool, Float, Int, String
from fastcs.mapping import Mapping
from fastcs.wrappers import command, scan
Expand Down Expand Up @@ -49,7 +52,20 @@ class TestHandler(Handler, TestUpdater, TestSender):
pass


class TestSubController(SubController):
read_int: AttrR = AttrR(Int(), handler=TestUpdater())


class TestController(Controller):
def __init__(self) -> None:
super().__init__()

self._sub_controllers: list[TestSubController] = []
for index in range(1, 3):
controller = TestSubController()
self._sub_controllers.append(controller)
self.register_sub_controller(f"SubController{index:02d}", controller)

read_int: AttrR = AttrR(Int(), handler=TestUpdater())
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
read_write_float: AttrRW = AttrRW(Float())
Expand Down Expand Up @@ -80,11 +96,58 @@ async def counter(self):
self.count += 1


class AssertableController(TestController):
def __init__(self, mocker: MockerFixture) -> None:
super().__init__()
self.mocker = mocker

@contextmanager
def assertPerformed(
self, path: list[str], action: Literal["READ", "WRITE", "EXECUTE"]
):
queue = copy.deepcopy(path)
match action:
case "READ":
method = "get"
case "WRITE":
method = "process"
case "EXECUTE":
method = ""

# Navigate to subcontroller
controller = self
item_name = queue.pop(-1)
for item in queue:
controllers = controller.get_sub_controllers()
controller = controllers[item]

# create probe
if method:
attr = getattr(controller, item_name)
spy = self.mocker.spy(attr, method)
else:
spy = self.mocker.spy(controller, item_name)
initial = spy.call_count
try:
yield # Enter context
finally: # Exit context
final = spy.call_count
assert final == initial + 1, (
f"Expected {'.'.join(path + [method] if method else path)} "
f"to be called once, but it was called {final - initial} times."
)


@pytest.fixture
def controller():
return TestController()


@pytest.fixture(scope="class")
def assertable_controller(class_mocker: MockerFixture):
return AssertableController(class_mocker)


@pytest.fixture
def mapping(controller):
return Mapping(controller)
Expand Down