Skip to content

Commit b74b540

Browse files
committed
Add check point
1 parent 09a9f49 commit b74b540

File tree

3 files changed

+145
-49
lines changed

3 files changed

+145
-49
lines changed

src/fastcs/backends/graphQL/graphQL.py

Lines changed: 107 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,17 @@ class GraphQLServerOptions:
2525
class GraphQLServer:
2626
def __init__(self, mapping: Mapping):
2727
self._mapping = mapping
28-
self._field_dict: dict[str, list[StrawberryField]] = {
29-
"Query": [],
30-
"Mutation": [],
31-
}
28+
self._fields_tree: FieldsTree = FieldsTree("")
3229
self._app = self._create_app()
3330

3431
def _create_app(self) -> FastAPI:
35-
_add_dev_attributes(self._field_dict, self._mapping)
36-
_add_dev_commands(self._field_dict, self._mapping)
32+
_add_dev_attributes(self._fields_tree, self._mapping)
33+
_add_dev_commands(self._fields_tree, self._mapping)
3734

3835
schema_kwargs = {}
39-
for key, value in self._field_dict.items():
40-
if self._field_dict[key]:
41-
# Strawberry types map to graphql object
42-
schema_kwargs[key.lower()] = create_type(key, value)
36+
for key in ["query", "mutation"]:
37+
if s_type := self._fields_tree.create_type(key):
38+
schema_kwargs[key] = s_type
4339
schema = strawberry.Schema(**schema_kwargs) # type: ignore
4440
graphql_app: GraphQL = GraphQL(schema)
4541

@@ -90,36 +86,94 @@ async def _dynamic_f() -> Any:
9086
return _dynamic_f
9187

9288

93-
def _add_dev_attributes(
94-
field_dict: dict[str, list[StrawberryField]], mapping: Mapping
95-
) -> None:
96-
for single_mapping in mapping.get_controller_mappings():
97-
path = single_mapping.controller.path
98-
# nest for each controller
99-
# if path:
89+
def _wrap_as_field(
90+
field_name: str,
91+
strawberry_type: type,
92+
) -> StrawberryField:
93+
def _dynamic_field():
94+
return strawberry_type()
10095

101-
for attr_name, attribute in single_mapping.attributes.items():
102-
attr_name = attr_name.title().replace("_", "")
103-
d_attr_name = f"{'/'.join(path)}/{attr_name}" if path else attr_name
96+
_dynamic_field.__name__ = field_name
97+
_dynamic_field.__annotations__["return"] = strawberry_type
10498

105-
match attribute:
106-
case AttrRW():
107-
field_dict["Query"].append(
108-
strawberry.field(_wrap_attr_get(d_attr_name, attribute))
109-
)
110-
field_dict["Mutation"].append(
111-
strawberry.mutation(_wrap_attr_set(d_attr_name, attribute))
112-
) # mutation for server changes https://graphql.org/learn/queries/
99+
return strawberry.field(_dynamic_field)
113100

114-
case AttrR():
115-
field_dict["Query"].append(
116-
strawberry.field(_wrap_attr_get(d_attr_name, attribute))
117-
)
118101

119-
case AttrW():
120-
field_dict["Mutation"].append(
121-
strawberry.mutation(_wrap_attr_set(d_attr_name, attribute))
122-
)
102+
class FieldsTree:
103+
def __init__(self, name: str):
104+
self.name = name
105+
self.children: list[FieldsTree] = []
106+
self.fields_dict: dict[str, list[StrawberryField]] = {
107+
"query": [],
108+
"mutation": [],
109+
}
110+
111+
def insert(self, path: list[str]):
112+
# Create child if not exist
113+
name = path.pop(0)
114+
if not (child := self.get_child(name)):
115+
child = FieldsTree(name)
116+
self.children.append(child)
117+
else:
118+
child = self.get_child(name)
119+
120+
# Recurse if needed
121+
if path:
122+
return child.insert(path) # type: ignore
123+
else:
124+
return child
125+
126+
def get_child(self, name: str):
127+
for child in self.children:
128+
if child.name == name:
129+
return child
130+
return None
131+
132+
def create_type(self, strawberry_type: str):
133+
for child in self.children:
134+
child_field = _wrap_as_field(
135+
child.name,
136+
child.create_type(strawberry_type),
137+
)
138+
self.fields_dict[strawberry_type].append(child_field)
139+
140+
return create_type(
141+
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
142+
)
143+
144+
145+
def _add_dev_attributes(
146+
fields_tree: FieldsTree,
147+
mapping: Mapping,
148+
) -> None:
149+
for single_mapping in mapping.get_controller_mappings():
150+
path = single_mapping.controller.path
151+
if path:
152+
node = fields_tree.insert(path)
153+
else:
154+
node = fields_tree
155+
156+
if node is not None:
157+
for attr_name, attribute in single_mapping.attributes.items():
158+
attr_name = attr_name.title().replace("_", "")
159+
160+
match attribute:
161+
# mutation for server changes https://graphql.org/learn/queries/
162+
case AttrRW():
163+
node.fields_dict["query"].append(
164+
strawberry.field(_wrap_attr_get(attr_name, attribute))
165+
)
166+
node.fields_dict["mutation"].append(
167+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
168+
)
169+
case AttrR():
170+
node.fields_dict["query"].append(
171+
strawberry.field(_wrap_attr_get(attr_name, attribute))
172+
)
173+
case AttrW():
174+
node.fields_dict["mutation"].append(
175+
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
176+
)
123177

124178

125179
def _wrap_command(
@@ -135,20 +189,25 @@ async def _dynamic_f() -> bool:
135189

136190

137191
def _add_dev_commands(
138-
field_dict: dict[str, list[StrawberryField]], mapping: Mapping
192+
fields_tree: FieldsTree,
193+
mapping: Mapping,
139194
) -> None:
140195
for single_mapping in mapping.get_controller_mappings():
141196
path = single_mapping.controller.path
142-
143-
for name, method in single_mapping.command_methods.items():
144-
cmd_name = name.title().replace("_", "")
145-
d_cmd_name = f"{'/'.join(path)}/{cmd_name}" if path else cmd_name
146-
field_dict["Mutation"].append(
147-
strawberry.mutation(
148-
_wrap_command(
149-
d_cmd_name,
150-
method.fn,
151-
single_mapping.controller,
197+
if path:
198+
node = fields_tree.insert(path)
199+
else:
200+
node = fields_tree
201+
202+
if node is not None:
203+
for name, method in single_mapping.command_methods.items():
204+
cmd_name = name.title().replace("_", "")
205+
node.fields_dict["mutation"].append(
206+
strawberry.mutation(
207+
_wrap_command(
208+
cmd_name,
209+
method.fn,
210+
single_mapping.controller,
211+
)
152212
)
153213
)
154-
)

tests/backends/graphQL/test_graphQL.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def client_exec(self, field: str):
2727
mutation = f"mutation {{{field}}}"
2828
response = self._client.post("/graphql", json={"query": mutation})
2929
assert response.status_code == 200
30+
assert response.json()["data"] == {field: True}
3031

3132
def test_read_int(self):
3233
self.client_read("ReadInt", AttrR(Int())._value)
@@ -55,3 +56,19 @@ def test_big_enum(self):
5556

5657
def test_go(self):
5758
self.client_exec("Go")
59+
60+
def test_read_child1(self):
61+
query = "query { SubController01 { ReadWriteInt }}"
62+
response = self._client.post("/graphql", json={"query": query})
63+
assert response.status_code == 200
64+
assert response.json()["data"] == {
65+
"SubController01": {"ReadWriteInt": AttrR(Int())._value}
66+
}
67+
68+
def test_read_child2(self):
69+
query = "query { SubController02 { ReadWriteInt }}"
70+
response = self._client.post("/graphql", json={"query": query})
71+
assert response.status_code == 200
72+
assert response.json()["data"] == {
73+
"SubController02": {"ReadWriteInt": AttrR(Int())._value}
74+
}

tests/conftest.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater
1414
from fastcs.backends.graphQL.backend import GraphQLBackend
15-
from fastcs.controller import Controller
15+
from fastcs.controller import Controller, SubController
1616
from fastcs.datatypes import Bool, Float, Int, String
1717
from fastcs.mapping import Mapping
1818
from fastcs.wrappers import command, scan
@@ -51,7 +51,27 @@ class TestHandler(Handler, TestUpdater, TestSender):
5151
pass
5252

5353

54+
class TestSubController(SubController):
55+
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
56+
57+
def __init__(self) -> None:
58+
super().__init__()
59+
60+
@command()
61+
async def go(self):
62+
pass
63+
64+
5465
class TestController(Controller):
66+
def __init__(self) -> None:
67+
super().__init__()
68+
69+
self._sub_controllers: list[TestSubController] = []
70+
for index in range(1, 3):
71+
controller = TestSubController()
72+
self._sub_controllers.append(controller)
73+
self.register_sub_controller(f"SubController{index:02d}", controller)
74+
5575
read_int: AttrR = AttrR(Int(), handler=TestUpdater())
5676
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
5777
read_write_float: AttrRW = AttrRW(Float())

0 commit comments

Comments
 (0)