Skip to content

Commit b98f07c

Browse files
committed
Attempt graphql backend
1 parent bdab4fa commit b98f07c

File tree

6 files changed

+419
-3
lines changed

6 files changed

+419
-3
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
"pvi~=0.10.0",
1818
"pytango",
1919
"softioc",
20+
"fastapi[standard]",
21+
"strawberry-graphql[fastapi]",
2022
]
2123
dynamic = ["version"]
2224
license.file = "LICENSE"
@@ -43,6 +45,7 @@ dev = [
4345
"types-mock",
4446
"aioca",
4547
"p4p",
48+
"httpx",
4649
]
4750

4851
[project.scripts]
@@ -61,7 +64,7 @@ version_file = "src/fastcs/_version.py"
6164

6265
[tool.pyright]
6366
typeCheckingMode = "standard"
64-
reportMissingImports = false # Ignore missing stubs in imported modules
67+
reportMissingImports = false # Ignore missing stubs in imported modules
6568

6669
[tool.pytest.ini_options]
6770
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error

src/fastcs/backends/graphQL/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastcs.backend import Backend
2+
from fastcs.controller import Controller
3+
4+
from .graphQL import GraphQLServer
5+
6+
7+
class GraphQLBackend(Backend):
8+
def __init__(self, controller: Controller):
9+
super().__init__(controller)
10+
11+
self._server = GraphQLServer(self._mapping)
12+
13+
def _run(self):
14+
self._server.run()
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from collections.abc import Awaitable, Callable, Coroutine
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
import strawberry
6+
import uvicorn
7+
from fastapi import FastAPI
8+
from strawberry.asgi import GraphQL
9+
from strawberry.tools import create_type
10+
from strawberry.types.field import StrawberryField
11+
12+
from fastcs.attributes import AttrR, AttrRW, AttrW, T
13+
from fastcs.controller import BaseController
14+
from fastcs.mapping import Mapping
15+
16+
17+
@dataclass
18+
class GraphQLServerOptions:
19+
host: str = "localhost"
20+
port: int = 8080
21+
log_level: str = "info"
22+
23+
24+
class GraphQLServer:
25+
def __init__(self, mapping: Mapping):
26+
self._mapping = mapping
27+
self._fields_tree: FieldsTree = FieldsTree("")
28+
self._app = self._create_app()
29+
30+
def _create_app(self) -> FastAPI:
31+
_add_dev_attributes(self._fields_tree, self._mapping)
32+
_add_dev_commands(self._fields_tree, self._mapping)
33+
34+
schema_kwargs = {}
35+
for key in ["query", "mutation"]:
36+
if s_type := self._fields_tree.create_type(key):
37+
schema_kwargs[key] = s_type
38+
schema = strawberry.Schema(**schema_kwargs) # type: ignore
39+
graphql_app: GraphQL = GraphQL(schema)
40+
41+
app = FastAPI()
42+
app.add_route("/graphql", graphql_app) # type: ignore
43+
app.add_websocket_route("/graphql", graphql_app) # type: ignore
44+
45+
return app
46+
47+
def run(self, options: GraphQLServerOptions | None = None) -> None:
48+
if options is None:
49+
options = GraphQLServerOptions()
50+
51+
uvicorn.run(
52+
self._app,
53+
host=options.host,
54+
port=options.port,
55+
log_level=options.log_level,
56+
)
57+
58+
59+
def _wrap_attr_set(
60+
d_attr_name: str,
61+
attribute: AttrW[T],
62+
) -> Callable[[T], Coroutine[Any, Any, None]]:
63+
async def _dynamic_f(value):
64+
await attribute.process(value)
65+
return value
66+
67+
# Add type annotations for validation, schema, conversions
68+
_dynamic_f.__name__ = d_attr_name
69+
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
70+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
71+
72+
return _dynamic_f
73+
74+
75+
def _wrap_attr_get(
76+
d_attr_name: str,
77+
attribute: AttrR[T],
78+
) -> Callable[[], Coroutine[Any, Any, Any]]:
79+
async def _dynamic_f() -> Any:
80+
return attribute.get()
81+
82+
_dynamic_f.__name__ = d_attr_name
83+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
84+
85+
return _dynamic_f
86+
87+
88+
def _wrap_as_field(
89+
field_name: str,
90+
strawberry_type: type,
91+
) -> StrawberryField:
92+
def _dynamic_field():
93+
return strawberry_type()
94+
95+
_dynamic_field.__name__ = field_name
96+
_dynamic_field.__annotations__["return"] = strawberry_type
97+
98+
return strawberry.field(_dynamic_field)
99+
100+
101+
class NodeNotFoundError(Exception):
102+
pass
103+
104+
105+
class FieldsTree:
106+
def __init__(self, name: str):
107+
self.name = name
108+
self.children: list[FieldsTree] = []
109+
self.fields_dict: dict[str, list[StrawberryField]] = {
110+
"query": [],
111+
"mutation": [],
112+
}
113+
114+
def insert(self, path: list[str]) -> "FieldsTree":
115+
# Create child if not exist
116+
name = path.pop(0)
117+
if self.is_child(name):
118+
child = self.get_child(name)
119+
else:
120+
child = FieldsTree(name)
121+
self.children.append(child)
122+
123+
# Recurse if needed
124+
if path:
125+
return child.insert(path) # type: ignore
126+
else:
127+
return child
128+
129+
def is_child(self, name: str) -> bool:
130+
for child in self.children:
131+
if child.name == name:
132+
return True
133+
return False
134+
135+
def get_child(self, name: str) -> "FieldsTree":
136+
for child in self.children:
137+
if child.name == name:
138+
return child
139+
raise NodeNotFoundError
140+
141+
def create_type(self, strawberry_type: str) -> type | None:
142+
for child in self.children:
143+
if new_type := child.create_type(strawberry_type):
144+
child_field = _wrap_as_field(
145+
child.name,
146+
new_type,
147+
)
148+
self.fields_dict[strawberry_type].append(child_field)
149+
150+
if self.fields_dict[strawberry_type]:
151+
return create_type(
152+
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
153+
)
154+
else:
155+
return None
156+
157+
158+
def _add_dev_attributes(
159+
fields_tree: FieldsTree,
160+
mapping: Mapping,
161+
) -> None:
162+
for single_mapping in mapping.get_controller_mappings():
163+
path = single_mapping.controller.path
164+
if path:
165+
node = fields_tree.insert(path)
166+
else:
167+
node = fields_tree
168+
169+
if node is not None:
170+
for attr_name, attribute in single_mapping.attributes.items():
171+
attr_name = attr_name.title().replace("_", "")
172+
173+
match attribute:
174+
# mutation for server changes https://graphql.org/learn/queries/
175+
case AttrRW():
176+
node.fields_dict["query"].append(
177+
strawberry.field(_wrap_attr_get(attr_name, attribute))
178+
)
179+
node.fields_dict["mutation"].append(
180+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
181+
)
182+
case AttrR():
183+
node.fields_dict["query"].append(
184+
strawberry.field(_wrap_attr_get(attr_name, attribute))
185+
)
186+
case AttrW():
187+
node.fields_dict["mutation"].append(
188+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
189+
)
190+
191+
192+
def _wrap_command(
193+
method_name: str, method: Callable, controller: BaseController
194+
) -> Callable[..., Awaitable[bool]]:
195+
async def _dynamic_f() -> bool:
196+
await method.__get__(controller)()
197+
return True
198+
199+
_dynamic_f.__name__ = method_name
200+
201+
return _dynamic_f
202+
203+
204+
def _add_dev_commands(
205+
fields_tree: FieldsTree,
206+
mapping: Mapping,
207+
) -> None:
208+
for single_mapping in mapping.get_controller_mappings():
209+
path = single_mapping.controller.path
210+
if path:
211+
node = fields_tree.insert(path)
212+
else:
213+
node = fields_tree
214+
215+
if node is not None:
216+
for name, method in single_mapping.command_methods.items():
217+
cmd_name = name.title().replace("_", "")
218+
node.fields_dict["mutation"].append(
219+
strawberry.mutation(
220+
_wrap_command(
221+
cmd_name,
222+
method.fn,
223+
single_mapping.controller,
224+
)
225+
)
226+
)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import copy
2+
import json
3+
import re
4+
from typing import Any
5+
6+
import pytest
7+
from fastapi.testclient import TestClient
8+
from pytest_mock import MockerFixture
9+
from tests.conftest import AssertableController
10+
11+
from fastcs.attributes import AttrR
12+
from fastcs.backends.graphQL.backend import GraphQLBackend
13+
from fastcs.datatypes import Bool, Float, Int
14+
15+
16+
def pascal_2_snake(input: list[str]) -> list[str]:
17+
snake_list = copy.deepcopy(input)
18+
snake_list[-1] = re.sub(r"(?<!^)(?=[A-Z])", "_", snake_list[-1]).lower()
19+
return snake_list
20+
21+
22+
def nest_query(path: list[str]) -> str:
23+
queue = copy.deepcopy(path)
24+
field = queue.pop(0)
25+
26+
if queue:
27+
nesting = nest_query(queue)
28+
return f"{field} {{ {nesting} }} "
29+
else:
30+
return field
31+
32+
33+
def nest_mutation(path: list[str], value: Any) -> str:
34+
queue = copy.deepcopy(path)
35+
field = queue.pop(0)
36+
37+
if queue:
38+
nesting = nest_query(queue)
39+
return f"{field} {{ {nesting} }} "
40+
else:
41+
return f"{field}(value: {json.dumps(value)})"
42+
43+
44+
def nest_responce(path: list[str], value: Any) -> dict:
45+
queue = copy.deepcopy(path)
46+
field = queue.pop(0)
47+
48+
if queue:
49+
nesting = nest_responce(queue, value)
50+
return {field: nesting}
51+
else:
52+
return {field: value}
53+
54+
55+
class TestGraphQLServer:
56+
@pytest.fixture(autouse=True)
57+
def setup_tests(self, mocker: MockerFixture):
58+
self.controller = AssertableController(mocker)
59+
app = GraphQLBackend(self.controller)._server._app
60+
self.client = TestClient(app)
61+
62+
def client_read(self, path: list[str], expected: Any):
63+
query = f"query {{ {nest_query(path)} }}"
64+
with self.controller.assertPerformed(pascal_2_snake(path), "READ"):
65+
response = self.client.post("/graphql", json={"query": query})
66+
assert response.status_code == 200
67+
assert response.json()["data"] == nest_responce(path, expected)
68+
69+
def client_write(self, path: list[str], value: Any):
70+
mutation = f"mutation {{ {nest_mutation(path, value)} }}"
71+
with self.controller.assertPerformed(pascal_2_snake(path), "WRITE"):
72+
response = self.client.post("/graphql", json={"query": mutation})
73+
assert response.status_code == 200
74+
assert response.json()["data"] == nest_responce(path, value)
75+
76+
def client_exec(self, path: list[str]):
77+
mutation = f"mutation {{ {nest_query(path)} }}"
78+
with self.controller.assertPerformed(pascal_2_snake(path), "EXECUTE"):
79+
response = self.client.post("/graphql", json={"query": mutation})
80+
assert response.status_code == 200
81+
assert response.json()["data"] == {path[-1]: True}
82+
83+
def test_read_int(self):
84+
self.client_read(["ReadInt"], AttrR(Int())._value)
85+
86+
def test_read_write_int(self):
87+
self.client_read(["ReadWriteInt"], AttrR(Int())._value)
88+
self.client_write(["ReadWriteInt"], AttrR(Int())._value)
89+
90+
def test_read_write_float(self):
91+
self.client_read(["ReadWriteFloat"], AttrR(Float())._value)
92+
self.client_write(["ReadWriteFloat"], AttrR(Float())._value)
93+
94+
def test_read_bool(self):
95+
self.client_read(["ReadBool"], AttrR(Bool())._value)
96+
97+
def test_write_bool(self):
98+
self.client_write(["WriteBool"], AttrR(Bool())._value)
99+
100+
# # We need to discuss enums
101+
# def test_string_enum(self):
102+
103+
def test_big_enum(self):
104+
self.client_read(
105+
["BigEnum"], AttrR(Int(), allowed_values=list(range(1, 18)))._value
106+
)
107+
108+
def test_go(self):
109+
self.client_exec(["Go"])
110+
111+
def test_read_child1(self):
112+
self.client_read(["SubController01", "ReadInt"], AttrR(Int())._value)
113+
114+
def test_read_child2(self):
115+
self.client_read(["SubController02", "ReadInt"], AttrR(Int())._value)

0 commit comments

Comments
 (0)