Skip to content

Commit beefad4

Browse files
committed
Attempt graphql backend
1 parent bdab4fa commit beefad4

File tree

5 files changed

+333
-2
lines changed

5 files changed

+333
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ dev = [
4343
"types-mock",
4444
"aioca",
4545
"p4p",
46+
"strawberry-graphql[fastapi]",
4647
]
4748

4849
[project.scripts]
@@ -61,7 +62,7 @@ version_file = "src/fastcs/_version.py"
6162

6263
[tool.pyright]
6364
typeCheckingMode = "standard"
64-
reportMissingImports = false # Ignore missing stubs in imported modules
65+
reportMissingImports = false # Ignore missing stubs in imported modules
6566

6667
[tool.pytest.ini_options]
6768
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastcs.backend import Backend
2+
from fastcs.controller import Controller
3+
4+
from .graphQL import GraphQLServer
5+
6+
7+
class GraphQLBackend(Backend):
8+
def __init__(self, controller: Controller):
9+
super().__init__(controller)
10+
11+
self._server = GraphQLServer(self._mapping)
12+
13+
def _run(self):
14+
self._server.run()
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from collections.abc import Awaitable, Callable, Coroutine
2+
from dataclasses import dataclass
3+
from types import MethodType
4+
from typing import Any
5+
6+
import strawberry
7+
import uvicorn
8+
from fastapi import FastAPI
9+
from strawberry.asgi import GraphQL
10+
from strawberry.tools import create_type
11+
from strawberry.types.field import StrawberryField
12+
13+
from fastcs.attributes import AttrR, AttrRW, AttrW, T
14+
from fastcs.controller import BaseController
15+
from fastcs.mapping import Mapping
16+
17+
18+
@dataclass
19+
class GraphQLServerOptions:
20+
host: str = "localhost"
21+
port: int = 8080
22+
log_level: str = "info"
23+
24+
25+
class GraphQLServer:
26+
def __init__(self, mapping: Mapping):
27+
self._mapping = mapping
28+
self._fields_tree: FieldsTree = FieldsTree("")
29+
self._app = self._create_app()
30+
31+
def _create_app(self) -> FastAPI:
32+
_add_dev_attributes(self._fields_tree, self._mapping)
33+
_add_dev_commands(self._fields_tree, self._mapping)
34+
35+
schema_kwargs = {}
36+
for key in ["query", "mutation"]:
37+
if s_type := self._fields_tree.create_type(key):
38+
schema_kwargs[key] = s_type
39+
schema = strawberry.Schema(**schema_kwargs) # type: ignore
40+
graphql_app: GraphQL = GraphQL(schema)
41+
42+
app = FastAPI()
43+
app.add_route("/graphql", graphql_app)
44+
app.add_websocket_route("/graphql", graphql_app)
45+
46+
return app
47+
48+
def run(self, options: GraphQLServerOptions | None = None) -> None:
49+
if options is None:
50+
options = GraphQLServerOptions()
51+
52+
uvicorn.run(
53+
self._app,
54+
host=options.host,
55+
port=options.port,
56+
log_level=options.log_level,
57+
)
58+
59+
60+
def _wrap_attr_set(
61+
d_attr_name: str,
62+
attribute: AttrW[T],
63+
) -> Callable[[T], Coroutine[Any, Any, None]]:
64+
async def _dynamic_f(value):
65+
await attribute.process_without_display_update(value)
66+
return value
67+
68+
# Add type annotations for validation, schema, conversions
69+
_dynamic_f.__name__ = d_attr_name
70+
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
71+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
72+
73+
return _dynamic_f
74+
75+
76+
def _wrap_attr_get(
77+
d_attr_name: str,
78+
attribute: AttrR[T],
79+
) -> Callable[[], Coroutine[Any, Any, Any]]:
80+
async def _dynamic_f() -> Any:
81+
return attribute.get()
82+
83+
_dynamic_f.__name__ = d_attr_name
84+
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype
85+
86+
return _dynamic_f
87+
88+
89+
def _wrap_as_field(
90+
field_name: str,
91+
strawberry_type: type,
92+
) -> StrawberryField:
93+
def _dynamic_field():
94+
return strawberry_type()
95+
96+
_dynamic_field.__name__ = field_name
97+
_dynamic_field.__annotations__["return"] = strawberry_type
98+
99+
return strawberry.field(_dynamic_field)
100+
101+
102+
class FieldsTree:
103+
def __init__(self, name: str):
104+
self.name = name
105+
self.children: list[FieldsTree] = []
106+
self.fields_dict: dict[str, list[StrawberryField]] = {
107+
"query": [],
108+
"mutation": [],
109+
}
110+
111+
def insert(self, path: list[str]):
112+
# Create child if not exist
113+
name = path.pop(0)
114+
if not (child := self.get_child(name)):
115+
child = FieldsTree(name)
116+
self.children.append(child)
117+
else:
118+
child = self.get_child(name)
119+
120+
# Recurse if needed
121+
if path:
122+
return child.insert(path) # type: ignore
123+
else:
124+
return child
125+
126+
def get_child(self, name: str):
127+
for child in self.children:
128+
if child.name == name:
129+
return child
130+
return None
131+
132+
def create_type(self, strawberry_type: str):
133+
for child in self.children:
134+
child_field = _wrap_as_field(
135+
child.name,
136+
child.create_type(strawberry_type),
137+
)
138+
self.fields_dict[strawberry_type].append(child_field)
139+
140+
return create_type(
141+
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
142+
)
143+
144+
145+
def _add_dev_attributes(
146+
fields_tree: FieldsTree,
147+
mapping: Mapping,
148+
) -> None:
149+
for single_mapping in mapping.get_controller_mappings():
150+
path = single_mapping.controller.path
151+
if path:
152+
node = fields_tree.insert(path)
153+
else:
154+
node = fields_tree
155+
156+
if node is not None:
157+
for attr_name, attribute in single_mapping.attributes.items():
158+
attr_name = attr_name.title().replace("_", "")
159+
160+
match attribute:
161+
# mutation for server changes https://graphql.org/learn/queries/
162+
case AttrRW():
163+
node.fields_dict["query"].append(
164+
strawberry.field(_wrap_attr_get(attr_name, attribute))
165+
)
166+
node.fields_dict["mutation"].append(
167+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
168+
)
169+
case AttrR():
170+
node.fields_dict["query"].append(
171+
strawberry.field(_wrap_attr_get(attr_name, attribute))
172+
)
173+
case AttrW():
174+
node.fields_dict["mutation"].append(
175+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
176+
)
177+
178+
179+
def _wrap_command(
180+
method_name: str, method: Callable, controller: BaseController
181+
) -> Callable[..., Awaitable[bool]]:
182+
async def _dynamic_f() -> bool:
183+
await MethodType(method, controller)()
184+
return True
185+
186+
_dynamic_f.__name__ = method_name
187+
188+
return _dynamic_f
189+
190+
191+
def _add_dev_commands(
192+
fields_tree: FieldsTree,
193+
mapping: Mapping,
194+
) -> None:
195+
for single_mapping in mapping.get_controller_mappings():
196+
path = single_mapping.controller.path
197+
if path:
198+
node = fields_tree.insert(path)
199+
else:
200+
node = fields_tree
201+
202+
if node is not None:
203+
for name, method in single_mapping.command_methods.items():
204+
cmd_name = name.title().replace("_", "")
205+
node.fields_dict["mutation"].append(
206+
strawberry.mutation(
207+
_wrap_command(
208+
cmd_name,
209+
method.fn,
210+
single_mapping.controller,
211+
)
212+
)
213+
)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
3+
import pytest
4+
5+
from fastcs.attributes import AttrR
6+
from fastcs.datatypes import Bool, Float, Int
7+
8+
9+
class TestGraphQLServer:
10+
@pytest.fixture(autouse=True)
11+
def _server_client(self, graphQL_client):
12+
self._client = graphQL_client
13+
14+
def client_read(self, field: str, value):
15+
query = f"query {{{field}}}"
16+
response = self._client.post("/graphql", json={"query": query})
17+
assert response.status_code == 200
18+
assert response.json()["data"] == {field: value}
19+
20+
def client_write(self, field: str, value):
21+
mutation = f"mutation {{{field}(value: {json.dumps(value)})}}"
22+
response = self._client.post("/graphql", json={"query": mutation})
23+
assert response.status_code == 200
24+
assert response.json()["data"] == {field: value}
25+
26+
def client_exec(self, field: str):
27+
mutation = f"mutation {{{field}}}"
28+
response = self._client.post("/graphql", json={"query": mutation})
29+
assert response.status_code == 200
30+
assert response.json()["data"] == {field: True}
31+
32+
def test_read_int(self):
33+
self.client_read("ReadInt", AttrR(Int())._value)
34+
35+
def test_read_write_int(self):
36+
self.client_read("ReadWriteInt", AttrR(Int())._value)
37+
self.client_write("ReadWriteInt", AttrR(Int())._value)
38+
39+
def test_read_write_float(self):
40+
self.client_read("ReadWriteFloat", AttrR(Float())._value)
41+
self.client_write("ReadWriteFloat", AttrR(Float())._value)
42+
43+
def test_read_bool(self):
44+
self.client_read("ReadBool", AttrR(Bool())._value)
45+
46+
def test_write_bool(self):
47+
self.client_write("WriteBool", AttrR(Bool())._value)
48+
49+
# # We need to discuss enums
50+
# def test_string_enum(self):
51+
52+
def test_big_enum(self):
53+
self.client_read(
54+
"BigEnum", AttrR(Int(), allowed_values=list(range(1, 18)))._value
55+
)
56+
57+
def test_go(self):
58+
self.client_exec("Go")
59+
60+
def test_read_child1(self):
61+
query = "query { SubController01 { ReadWriteInt }}"
62+
response = self._client.post("/graphql", json={"query": query})
63+
assert response.status_code == 200
64+
assert response.json()["data"] == {
65+
"SubController01": {"ReadWriteInt": AttrR(Int())._value}
66+
}
67+
68+
def test_read_child2(self):
69+
query = "query { SubController02 { ReadWriteInt }}"
70+
response = self._client.post("/graphql", json={"query": query})
71+
assert response.status_code == 200
72+
assert response.json()["data"] == {
73+
"SubController02": {"ReadWriteInt": AttrR(Int())._value}
74+
}

tests/conftest.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
import pytest
1010
from aioca import purge_channel_caches
11+
from fastapi.testclient import TestClient
1112

1213
from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater
13-
from fastcs.controller import Controller
14+
from fastcs.backends.graphQL.backend import GraphQLBackend
15+
from fastcs.controller import Controller, SubController
1416
from fastcs.datatypes import Bool, Float, Int, String
1517
from fastcs.mapping import Mapping
1618
from fastcs.wrappers import command, scan
@@ -49,7 +51,27 @@ class TestHandler(Handler, TestUpdater, TestSender):
4951
pass
5052

5153

54+
class TestSubController(SubController):
55+
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
56+
57+
def __init__(self) -> None:
58+
super().__init__()
59+
60+
@command()
61+
async def go(self):
62+
pass
63+
64+
5265
class TestController(Controller):
66+
def __init__(self) -> None:
67+
super().__init__()
68+
69+
self._sub_controllers: list[TestSubController] = []
70+
for index in range(1, 3):
71+
controller = TestSubController()
72+
self._sub_controllers.append(controller)
73+
self.register_sub_controller(f"SubController{index:02d}", controller)
74+
5375
read_int: AttrR = AttrR(Int(), handler=TestUpdater())
5476
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
5577
read_write_float: AttrRW = AttrRW(Float())
@@ -120,3 +142,10 @@ def ioc():
120142
except ValueError:
121143
# Someone else already called communicate
122144
pass
145+
146+
147+
@pytest.fixture(scope="class")
148+
def graphQL_client():
149+
app = GraphQLBackend(TestController())._server._app
150+
client = TestClient(app)
151+
return client

0 commit comments

Comments
 (0)