Skip to content

Commit 3adf849

Browse files
committed
Attempt graphql backend
1 parent bdab4fa commit 3adf849

File tree

6 files changed

+345
-2
lines changed

6 files changed

+345
-2
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
"pvi~=0.10.0",
1818
"pytango",
1919
"softioc",
20+
"fastapi[standard]",
21+
"strawberry-graphql[fastapi]",
2022
]
2123
dynamic = ["version"]
2224
license.file = "LICENSE"
@@ -43,6 +45,7 @@ dev = [
4345
"types-mock",
4446
"aioca",
4547
"p4p",
48+
"httpx",
4649
]
4750

4851
[project.scripts]
@@ -61,7 +64,7 @@ version_file = "src/fastcs/_version.py"
6164

6265
[tool.pyright]
6366
typeCheckingMode = "standard"
64-
reportMissingImports = false # Ignore missing stubs in imported modules
67+
reportMissingImports = false # Ignore missing stubs in imported modules
6568

6669
[tool.pytest.ini_options]
6770
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error

src/fastcs/backends/graphQL/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastcs.backend import Backend
2+
from fastcs.controller import Controller
3+
4+
from .graphQL import GraphQLServer
5+
6+
7+
class GraphQLBackend(Backend):
8+
def __init__(self, controller: Controller):
9+
super().__init__(controller)
10+
11+
self._server = GraphQLServer(self._mapping)
12+
13+
def _run(self):
14+
self._server.run()
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
from collections.abc import Awaitable, Callable, Coroutine
2+
from dataclasses import dataclass
3+
from types import MethodType
4+
from typing import Any
5+
6+
import strawberry
7+
import uvicorn
8+
from fastapi import FastAPI
9+
from strawberry.asgi import GraphQL
10+
from strawberry.tools import create_type
11+
from strawberry.types.field import StrawberryField
12+
13+
from fastcs.attributes import AttrR, AttrRW, AttrW, T
14+
from fastcs.controller import BaseController
15+
from fastcs.mapping import Mapping
16+
17+
18+
@dataclass
19+
class GraphQLServerOptions:
20+
host: str = "localhost"
21+
port: int = 8080
22+
log_level: str = "info"
23+
24+
25+
class GraphQLServer:
26+
def __init__(self, mapping: Mapping):
27+
self._mapping = mapping
28+
self._fields_tree: FieldsTree = FieldsTree("")
29+
self._app = self._create_app()
30+
31+
def _create_app(self) -> FastAPI:
32+
_add_dev_attributes(self._fields_tree, self._mapping)
33+
_add_dev_commands(self._fields_tree, self._mapping)
34+
35+
schema_kwargs = {}
36+
for key in ["query", "mutation"]:
37+
if s_type := self._fields_tree.create_type(key):
38+
schema_kwargs[key] = s_type
39+
schema = strawberry.Schema(**schema_kwargs) # type: ignore
40+
graphql_app: GraphQL = GraphQL(schema)
41+
42+
app = FastAPI()
43+
app.add_route("/graphql", graphql_app) # type: ignore
44+
app.add_websocket_route("/graphql", graphql_app) # type: ignore
45+
46+
return app
47+
48+
def run(self, options: GraphQLServerOptions | None = None) -> None:
49+
if options is None:
50+
options = GraphQLServerOptions()
51+
52+
uvicorn.run(
53+
self._app,
54+
host=options.host,
55+
port=options.port,
56+
log_level=options.log_level,
57+
)
58+
59+
60+
def _wrap_attr_set(
61+
d_attr_name: str,
62+
attribute: AttrW[T],
63+
) -> Callable[[T], Coroutine[Any, Any, None]]:
64+
async def _dynamic_f(value):
65+
await attribute.process_without_display_update(value)
66+
return value
67+
68+
# Add type annotations for validation, schema, conversions
69+
_dynamic_f.__name__ = d_attr_name
70+
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
71+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
72+
73+
return _dynamic_f
74+
75+
76+
def _wrap_attr_get(
77+
d_attr_name: str,
78+
attribute: AttrR[T],
79+
) -> Callable[[], Coroutine[Any, Any, Any]]:
80+
async def _dynamic_f() -> Any:
81+
return attribute.get()
82+
83+
_dynamic_f.__name__ = d_attr_name
84+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
85+
86+
return _dynamic_f
87+
88+
89+
def _wrap_as_field(
90+
field_name: str,
91+
strawberry_type: type,
92+
) -> StrawberryField:
93+
def _dynamic_field():
94+
return strawberry_type()
95+
96+
_dynamic_field.__name__ = field_name
97+
_dynamic_field.__annotations__["return"] = strawberry_type
98+
99+
return strawberry.field(_dynamic_field)
100+
101+
102+
class NodeNotFoundError(Exception):
103+
pass
104+
105+
106+
class FieldsTree:
107+
def __init__(self, name: str):
108+
self.name = name
109+
self.children: list[FieldsTree] = []
110+
self.fields_dict: dict[str, list[StrawberryField]] = {
111+
"query": [],
112+
"mutation": [],
113+
}
114+
115+
def insert(self, path: list[str]) -> "FieldsTree":
116+
# Create child if not exist
117+
name = path.pop(0)
118+
if self.is_child(name):
119+
child = self.get_child(name)
120+
else:
121+
child = FieldsTree(name)
122+
self.children.append(child)
123+
124+
# Recurse if needed
125+
if path:
126+
return child.insert(path) # type: ignore
127+
else:
128+
return child
129+
130+
def is_child(self, name: str) -> bool:
131+
for child in self.children:
132+
if child.name == name:
133+
return True
134+
return False
135+
136+
def get_child(self, name: str) -> "FieldsTree":
137+
for child in self.children:
138+
if child.name == name:
139+
return child
140+
raise NodeNotFoundError
141+
142+
def create_type(self, strawberry_type: str) -> type:
143+
for child in self.children:
144+
child_field = _wrap_as_field(
145+
child.name,
146+
child.create_type(strawberry_type),
147+
)
148+
self.fields_dict[strawberry_type].append(child_field)
149+
150+
return create_type(
151+
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
152+
)
153+
154+
155+
def _add_dev_attributes(
156+
fields_tree: FieldsTree,
157+
mapping: Mapping,
158+
) -> None:
159+
for single_mapping in mapping.get_controller_mappings():
160+
path = single_mapping.controller.path
161+
if path:
162+
node = fields_tree.insert(path)
163+
else:
164+
node = fields_tree
165+
166+
if node is not None:
167+
for attr_name, attribute in single_mapping.attributes.items():
168+
attr_name = attr_name.title().replace("_", "")
169+
170+
match attribute:
171+
# mutation for server changes https://graphql.org/learn/queries/
172+
case AttrRW():
173+
node.fields_dict["query"].append(
174+
strawberry.field(_wrap_attr_get(attr_name, attribute))
175+
)
176+
node.fields_dict["mutation"].append(
177+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
178+
)
179+
case AttrR():
180+
node.fields_dict["query"].append(
181+
strawberry.field(_wrap_attr_get(attr_name, attribute))
182+
)
183+
case AttrW():
184+
node.fields_dict["mutation"].append(
185+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
186+
)
187+
188+
189+
def _wrap_command(
190+
method_name: str, method: Callable, controller: BaseController
191+
) -> Callable[..., Awaitable[bool]]:
192+
async def _dynamic_f() -> bool:
193+
await MethodType(method, controller)()
194+
return True
195+
196+
_dynamic_f.__name__ = method_name
197+
198+
return _dynamic_f
199+
200+
201+
def _add_dev_commands(
202+
fields_tree: FieldsTree,
203+
mapping: Mapping,
204+
) -> None:
205+
for single_mapping in mapping.get_controller_mappings():
206+
path = single_mapping.controller.path
207+
if path:
208+
node = fields_tree.insert(path)
209+
else:
210+
node = fields_tree
211+
212+
if node is not None:
213+
for name, method in single_mapping.command_methods.items():
214+
cmd_name = name.title().replace("_", "")
215+
node.fields_dict["mutation"].append(
216+
strawberry.mutation(
217+
_wrap_command(
218+
cmd_name,
219+
method.fn,
220+
single_mapping.controller,
221+
)
222+
)
223+
)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
3+
import pytest
4+
5+
from fastcs.attributes import AttrR
6+
from fastcs.datatypes import Bool, Float, Int
7+
8+
9+
class TestGraphQLServer:
10+
@pytest.fixture(autouse=True)
11+
def _server_client(self, graphQL_client):
12+
self._client = graphQL_client
13+
14+
def client_read(self, field: str, value):
15+
query = f"query {{{field}}}"
16+
response = self._client.post("/graphql", json={"query": query})
17+
assert response.status_code == 200
18+
assert response.json()["data"] == {field: value}
19+
20+
def client_write(self, field: str, value):
21+
mutation = f"mutation {{{field}(value: {json.dumps(value)})}}"
22+
response = self._client.post("/graphql", json={"query": mutation})
23+
assert response.status_code == 200
24+
assert response.json()["data"] == {field: value}
25+
26+
def client_exec(self, field: str):
27+
mutation = f"mutation {{{field}}}"
28+
response = self._client.post("/graphql", json={"query": mutation})
29+
assert response.status_code == 200
30+
assert response.json()["data"] == {field: True}
31+
32+
def test_read_int(self):
33+
self.client_read("ReadInt", AttrR(Int())._value)
34+
35+
def test_read_write_int(self):
36+
self.client_read("ReadWriteInt", AttrR(Int())._value)
37+
self.client_write("ReadWriteInt", AttrR(Int())._value)
38+
39+
def test_read_write_float(self):
40+
self.client_read("ReadWriteFloat", AttrR(Float())._value)
41+
self.client_write("ReadWriteFloat", AttrR(Float())._value)
42+
43+
def test_read_bool(self):
44+
self.client_read("ReadBool", AttrR(Bool())._value)
45+
46+
def test_write_bool(self):
47+
self.client_write("WriteBool", AttrR(Bool())._value)
48+
49+
# # We need to discuss enums
50+
# def test_string_enum(self):
51+
52+
def test_big_enum(self):
53+
self.client_read(
54+
"BigEnum", AttrR(Int(), allowed_values=list(range(1, 18)))._value
55+
)
56+
57+
def test_go(self):
58+
self.client_exec("Go")
59+
60+
def test_read_child1(self):
61+
query = "query { SubController01 { ReadWriteInt }}"
62+
response = self._client.post("/graphql", json={"query": query})
63+
assert response.status_code == 200
64+
assert response.json()["data"] == {
65+
"SubController01": {"ReadWriteInt": AttrR(Int())._value}
66+
}
67+
68+
def test_read_child2(self):
69+
query = "query { SubController02 { ReadWriteInt }}"
70+
response = self._client.post("/graphql", json={"query": query})
71+
assert response.status_code == 200
72+
assert response.json()["data"] == {
73+
"SubController02": {"ReadWriteInt": AttrR(Int())._value}
74+
}

tests/conftest.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
import pytest
1010
from aioca import purge_channel_caches
11+
from fastapi.testclient import TestClient
1112

1213
from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater
13-
from fastcs.controller import Controller
14+
from fastcs.backends.graphQL.backend import GraphQLBackend
15+
from fastcs.controller import Controller, SubController
1416
from fastcs.datatypes import Bool, Float, Int, String
1517
from fastcs.mapping import Mapping
1618
from fastcs.wrappers import command, scan
@@ -49,7 +51,27 @@ class TestHandler(Handler, TestUpdater, TestSender):
4951
pass
5052

5153

54+
class TestSubController(SubController):
55+
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
56+
57+
def __init__(self) -> None:
58+
super().__init__()
59+
60+
@command()
61+
async def go(self):
62+
pass
63+
64+
5265
class TestController(Controller):
66+
def __init__(self) -> None:
67+
super().__init__()
68+
69+
self._sub_controllers: list[TestSubController] = []
70+
for index in range(1, 3):
71+
controller = TestSubController()
72+
self._sub_controllers.append(controller)
73+
self.register_sub_controller(f"SubController{index:02d}", controller)
74+
5375
read_int: AttrR = AttrR(Int(), handler=TestUpdater())
5476
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
5577
read_write_float: AttrRW = AttrRW(Float())
@@ -120,3 +142,10 @@ def ioc():
120142
except ValueError:
121143
# Someone else already called communicate
122144
pass
145+
146+
147+
@pytest.fixture(scope="class")
148+
def graphQL_client():
149+
app = GraphQLBackend(TestController())._server._app
150+
client = TestClient(app)
151+
return client

0 commit comments

Comments
 (0)