diff --git a/pyproject.toml b/pyproject.toml index b1c866980..2b7722cae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ classifiers = [ description = "Control system agnostic framework for building Device support in Python that will work for both EPICS and Tango" dependencies = [ "aioserial", + "fastapi[standard]", "numpy", "pydantic", "pvi~=0.10.0", @@ -43,6 +44,7 @@ dev = [ "types-mock", "aioca", "p4p", + "httpx", ] [project.scripts] diff --git a/src/fastcs/backends/rest/__init__.py b/src/fastcs/backends/rest/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/fastcs/backends/rest/backend.py b/src/fastcs/backends/rest/backend.py new file mode 100644 index 000000000..97b9d2322 --- /dev/null +++ b/src/fastcs/backends/rest/backend.py @@ -0,0 +1,14 @@ +from fastcs.backend import Backend +from fastcs.controller import Controller + +from .rest import RestServer + + +class RestBackend(Backend): + def __init__(self, controller: Controller): + super().__init__(controller) + + self._server = RestServer(self._mapping) + + def _run(self): + self._server.run() diff --git a/src/fastcs/backends/rest/rest.py b/src/fastcs/backends/rest/rest.py new file mode 100644 index 000000000..d41acdfde --- /dev/null +++ b/src/fastcs/backends/rest/rest.py @@ -0,0 +1,158 @@ +from collections.abc import Awaitable, Callable, Coroutine +from dataclasses import dataclass +from typing import Any + +import uvicorn +from fastapi import FastAPI +from pydantic import create_model + +from fastcs.attributes import AttrR, AttrRW, AttrW, T +from fastcs.controller import BaseController +from fastcs.mapping import Mapping + + +@dataclass +class RestServerOptions: + host: str = "localhost" + port: int = 8080 + log_level: str = "info" + + +class RestServer: + def __init__(self, mapping: Mapping): + self._mapping = mapping + self._app = self._create_app() + + def _create_app(self): + app = FastAPI() + _add_attribute_api_routes(app, self._mapping) + _add_command_api_routes(app, self._mapping) + + return app + + def run(self, options: RestServerOptions | None = None) -> None: + if options is None: + options = RestServerOptions() + + uvicorn.run( + self._app, + host=options.host, + port=options.port, + log_level=options.log_level, + ) + + +def _put_request_body(attribute: AttrW[T]): + """ + Creates a pydantic model for each datatype which defines the schema + of the PUT request body + """ + type_name = str(attribute.datatype.dtype.__name__).title() + # key=(type, ...) to declare a field without default value + return create_model( + f"Put{type_name}Value", + value=(attribute.datatype.dtype, ...), + ) + + +def _wrap_attr_put( + attribute: AttrW[T], +) -> Callable[[T], Coroutine[Any, Any, None]]: + async def attr_set(request): + await attribute.process(request.value) + + # Fast api uses type annotations for validation, schema, conversions + attr_set.__annotations__["request"] = _put_request_body(attribute) + + return attr_set + + +def _get_response_body(attribute: AttrR[T]): + """ + Creates a pydantic model for each datatype which defines the schema + of the GET request body + """ + type_name = str(attribute.datatype.dtype.__name__).title() + # key=(type, ...) to declare a field without default value + return create_model( + f"Get{type_name}Value", + value=(attribute.datatype.dtype, ...), + ) + + +def _wrap_attr_get( + attribute: AttrR[T], +) -> Callable[[], Coroutine[Any, Any, Any]]: + async def attr_get() -> Any: # Must be any as response_model is set + value = attribute.get() # type: ignore + return {"value": value} + + return attr_get + + +def _add_attribute_api_routes(app: FastAPI, mapping: Mapping) -> None: + for single_mapping in mapping.get_controller_mappings(): + path = single_mapping.controller.path + + for attr_name, attribute in single_mapping.attributes.items(): + attr_name = attr_name.replace("_", "-") + route = f"{'/'.join(path)}/{attr_name}" if path else attr_name + + match attribute: + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods + case AttrRW(): + app.add_api_route( + f"/{route}", + _wrap_attr_get(attribute), + methods=["GET"], # Idempotent and safe data retrieval, + status_code=200, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/GET + response_model=_get_response_body(attribute), + ) + app.add_api_route( + f"/{route}", + _wrap_attr_put(attribute), + methods=["PUT"], # Idempotent state change + status_code=204, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/PUT + ) + case AttrR(): + app.add_api_route( + f"/{route}", + _wrap_attr_get(attribute), + methods=["GET"], + status_code=200, + response_model=_get_response_body(attribute), + ) + case AttrW(): + app.add_api_route( + f"/{route}", + _wrap_attr_put(attribute), + methods=["PUT"], + status_code=204, + ) + + +def _wrap_command( + method: Callable, controller: BaseController +) -> Callable[..., Awaitable[None]]: + async def command() -> None: + await getattr(controller, method.__name__)() + + return command + + +def _add_command_api_routes(app: FastAPI, mapping: Mapping) -> None: + for single_mapping in mapping.get_controller_mappings(): + path = single_mapping.controller.path + + for name, method in single_mapping.command_methods.items(): + cmd_name = name.replace("_", "-") + route = f"/{'/'.join(path)}/{cmd_name}" if path else cmd_name + app.add_api_route( + f"/{route}", + _wrap_command( + method.fn, + single_mapping.controller, + ), + methods=["PUT"], + status_code=204, + ) diff --git a/tests/backends/rest/test_rest.py b/tests/backends/rest/test_rest.py new file mode 100644 index 000000000..cd506acd2 --- /dev/null +++ b/tests/backends/rest/test_rest.py @@ -0,0 +1,88 @@ +import pytest +from fastapi.testclient import TestClient + +from fastcs.backends.rest.backend import RestBackend + + +class TestRestServer: + @pytest.fixture(scope="class") + def client(self, assertable_controller): + app = RestBackend(assertable_controller)._server._app + return TestClient(app) + + def test_read_int(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["read_int"]): + response = client.get("/read-int") + assert response.status_code == 200 + assert response.json()["value"] == expect + + def test_read_write_int(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["read_write_int"]): + response = client.get("/read-write-int") + assert response.status_code == 200 + assert response.json()["value"] == expect + new = 9 + with assertable_controller.assert_write_here(["read_write_int"]): + response = client.put("/read-write-int", json={"value": new}) + assert client.get("/read-write-int").json()["value"] == new + + def test_read_write_float(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["read_write_float"]): + response = client.get("/read-write-float") + assert response.status_code == 200 + assert response.json()["value"] == expect + new = 0.5 + with assertable_controller.assert_write_here(["read_write_float"]): + response = client.put("/read-write-float", json={"value": new}) + assert client.get("/read-write-float").json()["value"] == new + + def test_read_bool(self, assertable_controller, client): + expect = False + with assertable_controller.assert_read_here(["read_bool"]): + response = client.get("/read-bool") + assert response.status_code == 200 + assert response.json()["value"] == expect + + def test_write_bool(self, assertable_controller, client): + with assertable_controller.assert_write_here(["write_bool"]): + client.put("/write-bool", json={"value": True}) + + def test_string_enum(self, assertable_controller, client): + expect = "" + with assertable_controller.assert_read_here(["string_enum"]): + response = client.get("/string-enum") + assert response.status_code == 200 + assert response.json()["value"] == expect + new = "new" + with assertable_controller.assert_write_here(["string_enum"]): + response = client.put("/string-enum", json={"value": new}) + assert client.get("/string-enum").json()["value"] == new + + def test_big_enum(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["big_enum"]): + response = client.get("/big-enum") + assert response.status_code == 200 + assert response.json()["value"] == expect + + def test_go(self, assertable_controller, client): + with assertable_controller.assert_execute_here(["go"]): + response = client.put("/go") + assert response.status_code == 204 + + def test_read_child1(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["SubController01", "read_int"]): + response = client.get("/SubController01/read-int") + assert response.status_code == 200 + assert response.json()["value"] == expect + + def test_read_child2(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["SubController02", "read_int"]): + response = client.get("/SubController02/read-int") + assert response.status_code == 200 + assert response.json()["value"] == expect