Skip to content

Commit e7bfc55

Browse files
committed
Attempt graphql backend
1 parent 9d3d9bf commit e7bfc55

File tree

5 files changed

+368
-1
lines changed

5 files changed

+368
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"pvi~=0.10.0",
1919
"pytango",
2020
"softioc>=4.5.0",
21+
"strawberry-graphql[fastapi]",
2122
]
2223
dynamic = ["version"]
2324
license.file = "LICENSE"
@@ -63,7 +64,7 @@ version_file = "src/fastcs/_version.py"
6364

6465
[tool.pyright]
6566
typeCheckingMode = "standard"
66-
reportMissingImports = false # Ignore missing stubs in imported modules
67+
reportMissingImports = false # Ignore missing stubs in imported modules
6768

6869
[tool.pytest.ini_options]
6970
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error

src/fastcs/backends/graphQL/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastcs.backend import Backend
2+
from fastcs.controller import Controller
3+
4+
from .graphQL import GraphQLServer
5+
6+
7+
class GraphQLBackend(Backend):
8+
def __init__(self, controller: Controller):
9+
super().__init__(controller)
10+
11+
self._server = GraphQLServer(self._mapping)
12+
13+
def _run(self):
14+
self._server.run()
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from collections.abc import Awaitable, Callable, Coroutine
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
import strawberry
6+
import uvicorn
7+
from fastapi import FastAPI
8+
from strawberry.asgi import GraphQL
9+
from strawberry.tools import create_type
10+
from strawberry.types.field import StrawberryField
11+
12+
from fastcs.attributes import AttrR, AttrRW, AttrW, T
13+
from fastcs.controller import BaseController
14+
from fastcs.mapping import Mapping
15+
16+
17+
@dataclass
18+
class GraphQLServerOptions:
19+
host: str = "localhost"
20+
port: int = 8080
21+
log_level: str = "info"
22+
23+
24+
class GraphQLServer:
25+
def __init__(self, mapping: Mapping):
26+
self._mapping = mapping
27+
self._fields_tree: FieldsTree = FieldsTree("")
28+
self._app = self._create_app()
29+
30+
def _create_app(self) -> FastAPI:
31+
_add_dev_attributes(self._fields_tree, self._mapping)
32+
_add_dev_commands(self._fields_tree, self._mapping)
33+
34+
schema_kwargs = {}
35+
for key in ["query", "mutation"]:
36+
if s_type := self._fields_tree.create_type(key):
37+
schema_kwargs[key] = s_type
38+
schema = strawberry.Schema(**schema_kwargs) # type: ignore
39+
graphql_app: GraphQL = GraphQL(schema)
40+
41+
app = FastAPI()
42+
app.add_route("/graphql", graphql_app) # type: ignore
43+
app.add_websocket_route("/graphql", graphql_app) # type: ignore
44+
45+
return app
46+
47+
def run(self, options: GraphQLServerOptions | None = None) -> None:
48+
if options is None:
49+
options = GraphQLServerOptions()
50+
51+
uvicorn.run(
52+
self._app,
53+
host=options.host,
54+
port=options.port,
55+
log_level=options.log_level,
56+
)
57+
58+
59+
def _wrap_attr_set(
60+
d_attr_name: str,
61+
attribute: AttrW[T],
62+
) -> Callable[[T], Coroutine[Any, Any, None]]:
63+
async def _dynamic_f(value):
64+
await attribute.process(value)
65+
return value
66+
67+
# Add type annotations for validation, schema, conversions
68+
_dynamic_f.__name__ = d_attr_name
69+
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
70+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
71+
72+
return _dynamic_f
73+
74+
75+
def _wrap_attr_get(
76+
d_attr_name: str,
77+
attribute: AttrR[T],
78+
) -> Callable[[], Coroutine[Any, Any, Any]]:
79+
async def _dynamic_f() -> Any:
80+
return attribute.get()
81+
82+
_dynamic_f.__name__ = d_attr_name
83+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
84+
85+
return _dynamic_f
86+
87+
88+
def _wrap_as_field(
89+
field_name: str,
90+
strawberry_type: type,
91+
) -> StrawberryField:
92+
def _dynamic_field():
93+
return strawberry_type()
94+
95+
_dynamic_field.__name__ = field_name
96+
_dynamic_field.__annotations__["return"] = strawberry_type
97+
98+
return strawberry.field(_dynamic_field)
99+
100+
101+
class NodeNotFoundError(Exception):
102+
pass
103+
104+
105+
class FieldsTree:
106+
def __init__(self, name: str):
107+
self.name = name
108+
self.children: list[FieldsTree] = []
109+
self.fields_dict: dict[str, list[StrawberryField]] = {
110+
"query": [],
111+
"mutation": [],
112+
}
113+
114+
def insert(self, path: list[str]) -> "FieldsTree":
115+
# Create child if not exist
116+
name = path.pop(0)
117+
if self.is_child(name):
118+
child = self.get_child(name)
119+
else:
120+
child = FieldsTree(name)
121+
self.children.append(child)
122+
123+
# Recurse if needed
124+
if path:
125+
return child.insert(path) # type: ignore
126+
else:
127+
return child
128+
129+
def is_child(self, name: str) -> bool:
130+
for child in self.children:
131+
if child.name == name:
132+
return True
133+
return False
134+
135+
def get_child(self, name: str) -> "FieldsTree":
136+
for child in self.children:
137+
if child.name == name:
138+
return child
139+
raise NodeNotFoundError
140+
141+
def create_type(self, strawberry_type: str) -> type | None:
142+
for child in self.children:
143+
if new_type := child.create_type(strawberry_type):
144+
child_field = _wrap_as_field(
145+
child.name,
146+
new_type,
147+
)
148+
self.fields_dict[strawberry_type].append(child_field)
149+
150+
if self.fields_dict[strawberry_type]:
151+
return create_type(
152+
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
153+
)
154+
else:
155+
return None
156+
157+
158+
def _add_dev_attributes(
159+
fields_tree: FieldsTree,
160+
mapping: Mapping,
161+
) -> None:
162+
for single_mapping in mapping.get_controller_mappings():
163+
path = single_mapping.controller.path
164+
if path:
165+
node = fields_tree.insert(path)
166+
else:
167+
node = fields_tree
168+
169+
if node is not None:
170+
for attr_name, attribute in single_mapping.attributes.items():
171+
attr_name = attr_name.title().replace("_", "")
172+
173+
match attribute:
174+
# mutation for server changes https://graphql.org/learn/queries/
175+
case AttrRW():
176+
node.fields_dict["query"].append(
177+
strawberry.field(_wrap_attr_get(attr_name, attribute))
178+
)
179+
node.fields_dict["mutation"].append(
180+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
181+
)
182+
case AttrR():
183+
node.fields_dict["query"].append(
184+
strawberry.field(_wrap_attr_get(attr_name, attribute))
185+
)
186+
case AttrW():
187+
node.fields_dict["mutation"].append(
188+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
189+
)
190+
191+
192+
def _wrap_command(
193+
method_name: str, method: Callable, controller: BaseController
194+
) -> Callable[..., Awaitable[bool]]:
195+
async def _dynamic_f() -> bool:
196+
await getattr(controller, method.__name__)()
197+
return True
198+
199+
_dynamic_f.__name__ = method_name
200+
201+
return _dynamic_f
202+
203+
204+
def _add_dev_commands(
205+
fields_tree: FieldsTree,
206+
mapping: Mapping,
207+
) -> None:
208+
for single_mapping in mapping.get_controller_mappings():
209+
path = single_mapping.controller.path
210+
if path:
211+
node = fields_tree.insert(path)
212+
else:
213+
node = fields_tree
214+
215+
if node is not None:
216+
for name, method in single_mapping.command_methods.items():
217+
cmd_name = name.title().replace("_", "")
218+
node.fields_dict["mutation"].append(
219+
strawberry.mutation(
220+
_wrap_command(
221+
cmd_name,
222+
method.fn,
223+
single_mapping.controller,
224+
)
225+
)
226+
)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import copy
2+
import json
3+
import re
4+
from typing import Any
5+
6+
import pytest
7+
from fastapi.testclient import TestClient
8+
9+
from fastcs.attributes import AttrR
10+
from fastcs.backends.graphQL.backend import GraphQLBackend
11+
from fastcs.datatypes import Bool, Float, Int
12+
13+
14+
def pascal_2_snake(input: list[str]) -> list[str]:
15+
snake_list = copy.deepcopy(input)
16+
snake_list[-1] = re.sub(r"(?<!^)(?=[A-Z])", "_", snake_list[-1]).lower()
17+
return snake_list
18+
19+
20+
def nest_query(path: list[str]) -> str:
21+
queue = copy.deepcopy(path)
22+
field = queue.pop(0)
23+
24+
if queue:
25+
nesting = nest_query(queue)
26+
return f"{field} {{ {nesting} }} "
27+
else:
28+
return field
29+
30+
31+
def nest_mutation(path: list[str], value: Any) -> str:
32+
queue = copy.deepcopy(path)
33+
field = queue.pop(0)
34+
35+
if queue:
36+
nesting = nest_query(queue)
37+
return f"{field} {{ {nesting} }} "
38+
else:
39+
return f"{field}(value: {json.dumps(value)})"
40+
41+
42+
def nest_responce(path: list[str], value: Any) -> dict:
43+
queue = copy.deepcopy(path)
44+
field = queue.pop(0)
45+
46+
if queue:
47+
nesting = nest_responce(queue, value)
48+
return {field: nesting}
49+
else:
50+
return {field: value}
51+
52+
53+
class TestGraphQLServer:
54+
@pytest.fixture(scope="class", autouse=True)
55+
def setup_class(self, assertable_controller):
56+
self.controller = assertable_controller
57+
58+
@pytest.fixture(scope="class")
59+
def client(self):
60+
app = GraphQLBackend(self.controller)._server._app
61+
return TestClient(app)
62+
63+
@pytest.fixture(scope="class")
64+
def client_read(self, client):
65+
def _client_read(path: list[str], expected: Any):
66+
query = f"query {{ {nest_query(path)} }}"
67+
with self.controller.assertPerformed(pascal_2_snake(path), "READ"):
68+
response = client.post("/graphql", json={"query": query})
69+
assert response.status_code == 200
70+
assert response.json()["data"] == nest_responce(path, expected)
71+
72+
return _client_read
73+
74+
@pytest.fixture(scope="class")
75+
def client_write(self, client):
76+
def _client_write(path: list[str], value: Any):
77+
mutation = f"mutation {{ {nest_mutation(path, value)} }}"
78+
with self.controller.assertPerformed(pascal_2_snake(path), "WRITE"):
79+
response = client.post("/graphql", json={"query": mutation})
80+
assert response.status_code == 200
81+
assert response.json()["data"] == nest_responce(path, value)
82+
83+
return _client_write
84+
85+
@pytest.fixture(scope="class")
86+
def client_exec(self, client):
87+
def _client_exec(path: list[str]):
88+
mutation = f"mutation {{ {nest_query(path)} }}"
89+
with self.controller.assertPerformed(pascal_2_snake(path), "EXECUTE"):
90+
response = client.post("/graphql", json={"query": mutation})
91+
assert response.status_code == 200
92+
assert response.json()["data"] == {path[-1]: True}
93+
94+
return _client_exec
95+
96+
def test_read_int(self, client_read):
97+
client_read(["ReadInt"], AttrR(Int())._value)
98+
99+
def test_read_write_int(self, client_read, client_write):
100+
client_read(["ReadWriteInt"], AttrR(Int())._value)
101+
client_write(["ReadWriteInt"], AttrR(Int())._value)
102+
103+
def test_read_write_float(self, client_read, client_write):
104+
client_read(["ReadWriteFloat"], AttrR(Float())._value)
105+
client_write(["ReadWriteFloat"], AttrR(Float())._value)
106+
107+
def test_read_bool(self, client_read):
108+
client_read(["ReadBool"], AttrR(Bool())._value)
109+
110+
def test_write_bool(self, client_write):
111+
client_write(["WriteBool"], AttrR(Bool())._value)
112+
113+
# # We need to discuss enums
114+
# def test_string_enum(self, client_read, client_write):
115+
116+
def test_big_enum(self, client_read):
117+
client_read(["BigEnum"], AttrR(Int(), allowed_values=list(range(1, 18)))._value)
118+
119+
def test_go(self, client_exec):
120+
client_exec(["Go"])
121+
122+
def test_read_child1(self, client_read):
123+
client_read(["SubController01", "ReadInt"], AttrR(Int())._value)
124+
125+
def test_read_child2(self, client_read):
126+
client_read(["SubController02", "ReadInt"], AttrR(Int())._value)

0 commit comments

Comments
 (0)