Skip to content

Commit cc4bbac

Browse files
committed
refact: converter.py resolver code by separating resolvers to separate classes
1 parent 07101f6 commit cc4bbac

File tree

7 files changed

+554
-516
lines changed

7 files changed

+554
-516
lines changed

graphene_mongo/converter.py

Lines changed: 46 additions & 495 deletions
Large diffs are not rendered by default.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .dynamic_lazy_field_resolver import DynamicLazyFieldResolver
2+
from .dynamic_reference_field_resolver import DynamicReferenceFieldResolver
3+
from .list_field_resolver import ListFieldResolver
4+
from .union_resolver import UnionFieldResolver
5+
6+
__all__ = [
7+
"DynamicLazyFieldResolver",
8+
"DynamicReferenceFieldResolver",
9+
"ListFieldResolver",
10+
"UnionFieldResolver",
11+
]
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from collections.abc import Callable
2+
from typing import Optional, Union
3+
4+
from bson import ObjectId
5+
from graphene.utils.str_converters import to_snake_case
6+
from graphene_mongo.utils import (
7+
ExecutorEnum,
8+
get_query_fields,
9+
sync_to_async,
10+
)
11+
from mongoengine import Document
12+
13+
14+
class DynamicLazyFieldResolver:
15+
@staticmethod
16+
def __lazy_resolver_common(
17+
field, registry, executor: ExecutorEnum, root, *args, **kwargs
18+
) -> Optional[Union[tuple[Document, set[str], ObjectId], Document]]:
19+
document = getattr(root, field.name or field.db_name)
20+
if not document:
21+
return None
22+
if document._cached_doc:
23+
return document._cached_doc
24+
25+
queried_fields = []
26+
_type = registry.get_type_for_model(document.document_type, executor=executor)
27+
filter_args = []
28+
if _type._meta.filter_fields:
29+
for key, values in _type._meta.filter_fields.items():
30+
for each in values:
31+
filter_args.append(key + "__" + each)
32+
for each in get_query_fields(args[0]).keys():
33+
item = to_snake_case(each)
34+
if item in document.document_type._fields_ordered + tuple(filter_args):
35+
queried_fields.append(item)
36+
37+
only_fields = set((list(_type._meta.required_fields) + queried_fields))
38+
39+
return document.document_type, only_fields, document.id
40+
41+
@staticmethod
42+
def lazy_resolver(field, registry, executor) -> Callable:
43+
def resolver(root, *args, **kwargs) -> Optional[Document]:
44+
result = DynamicLazyFieldResolver.__lazy_resolver_common(
45+
field, registry, executor, root, *args, **kwargs
46+
)
47+
if not isinstance(result, tuple):
48+
return result
49+
document, only_fields, pk = result
50+
return document.objects.no_dereference().only(*only_fields).get(pk=pk)
51+
52+
return resolver
53+
54+
@staticmethod
55+
def lazy_resolver_async(field, registry, executor) -> Callable:
56+
async def resolver(root, *args, **kwargs) -> Optional[Document]:
57+
result = DynamicLazyFieldResolver.__lazy_resolver_common(
58+
field, registry, executor, root, *args, **kwargs
59+
)
60+
if not isinstance(result, tuple):
61+
return result
62+
document, only_fields, pk = result
63+
return await sync_to_async(document.objects.no_dereference().only(*only_fields).get)(
64+
pk=pk
65+
)
66+
67+
return resolver
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from collections.abc import Callable
2+
from typing import Optional, Union
3+
4+
from bson import ObjectId
5+
from graphene.utils.str_converters import to_snake_case
6+
from graphene_mongo.utils import (
7+
ExecutorEnum,
8+
get_query_fields,
9+
sync_to_async,
10+
)
11+
from mongoengine import Document, ReferenceField
12+
13+
14+
class DynamicReferenceFieldResolver:
15+
@staticmethod
16+
def __reference_resolver_common(
17+
field, registry, executor: ExecutorEnum, root, *args, **kwargs
18+
) -> Optional[Union[tuple[Document, set[str], ObjectId], Document]]:
19+
document = root._data.get(field.name or field.db_name, None)
20+
if not document:
21+
return None
22+
23+
queried_fields = list()
24+
_type = registry.get_type_for_model(field.document_type, executor=executor)
25+
filter_args = list()
26+
if _type._meta.filter_fields:
27+
for key, values in _type._meta.filter_fields.items():
28+
for each in values:
29+
filter_args.append(key + "__" + each)
30+
for each in get_query_fields(args[0]).keys():
31+
item = to_snake_case(each)
32+
if item in field.document_type._fields_ordered + tuple(filter_args):
33+
queried_fields.append(item)
34+
35+
fields_to_fetch = set(list(_type._meta.required_fields) + queried_fields)
36+
if isinstance(document, field.document_type) and all(
37+
document._data[_field] is not None for _field in fields_to_fetch
38+
):
39+
return document # Data is already fetched
40+
41+
document_id = (
42+
document.id
43+
if isinstance(field, ReferenceField)
44+
else getattr(root, field.name or field.db_name)
45+
)
46+
return field.document_type, fields_to_fetch, document_id
47+
48+
@staticmethod
49+
def reference_resolver(field, registry, executor) -> Callable:
50+
def resolver(root, *args, **kwargs) -> Optional[Document]:
51+
result = DynamicReferenceFieldResolver.__reference_resolver_common(
52+
field, registry, executor, root, *args, **kwargs
53+
)
54+
if not isinstance(result, tuple):
55+
return result
56+
document, only_fields, pk = result
57+
return document.objects.no_dereference().only(*only_fields).get(pk=pk)
58+
59+
return resolver
60+
61+
@staticmethod
62+
def reference_resolver_async(field, registry, executor) -> Callable:
63+
async def resolver(root, *args, **kwargs) -> Optional[Document]:
64+
result = DynamicReferenceFieldResolver.__reference_resolver_common(
65+
field, registry, executor, root, *args, **kwargs
66+
)
67+
if not isinstance(result, tuple):
68+
return result
69+
document, only_fields, pk = result
70+
return await sync_to_async(document.objects.no_dereference().only(*only_fields).get)(
71+
pk=pk
72+
)
73+
74+
return resolver
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import asyncio
2+
from asyncio import Future, Task
3+
from collections.abc import Callable
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
from typing import Optional, Union
6+
7+
from bson import ObjectId
8+
from graphene.utils.str_converters import to_snake_case
9+
from graphene_mongo.utils import (
10+
ExecutorEnum,
11+
get_queried_union_types,
12+
sync_to_async,
13+
)
14+
import mongoengine
15+
from mongoengine import Document
16+
from mongoengine.base import LazyReference, get_document
17+
18+
19+
class ListFieldResolver:
20+
@staticmethod
21+
def __get_reference_objects_common(
22+
registry,
23+
model,
24+
executor: ExecutorEnum,
25+
object_id_list: list[ObjectId],
26+
queried_fields: dict,
27+
) -> tuple[Document, set[str], list[ObjectId]]:
28+
from graphene_mongo.converter import convert_mongoengine_field
29+
30+
document = get_document(model)
31+
document_field = mongoengine.ReferenceField(document)
32+
document_field = convert_mongoengine_field(document_field, registry, executor)
33+
document_field_type = document_field.get_type().type
34+
_queried_fields = list()
35+
filter_args = list()
36+
if document_field_type._meta.filter_fields:
37+
for key, values in document_field_type._meta.filter_fields.items():
38+
for each in values:
39+
filter_args.append(key + "__" + each)
40+
for each in queried_fields:
41+
item = to_snake_case(each)
42+
if item in document._fields_ordered + tuple(filter_args):
43+
_queried_fields.append(item)
44+
45+
only_fields = set(list(document_field_type._meta.required_fields) + _queried_fields)
46+
return document, only_fields, object_id_list
47+
48+
# ======================= DB CALLS =======================
49+
@staticmethod
50+
def __get_reference_objects(
51+
registry,
52+
model,
53+
executor: ExecutorEnum,
54+
object_id_list: list[ObjectId],
55+
queried_fields: dict,
56+
):
57+
document, only_fields, document_ids = ListFieldResolver.__get_reference_objects_common(
58+
registry, model, executor, object_id_list, queried_fields
59+
)
60+
return document.objects().no_dereference().only(*only_fields).filter(pk__in=document_ids)
61+
62+
@staticmethod
63+
async def __get_reference_objects_async(
64+
registry,
65+
model,
66+
executor: ExecutorEnum,
67+
object_id_list: list[ObjectId],
68+
queried_fields: dict,
69+
):
70+
document, only_fields, document_ids = ListFieldResolver.__get_reference_objects_common(
71+
registry, model, executor, object_id_list, queried_fields
72+
)
73+
return await sync_to_async(list)(
74+
document.objects().no_dereference().only(*only_fields).filter(pk__in=document_ids)
75+
)
76+
77+
# ======================= DB CALLS: END =======================
78+
79+
@staticmethod
80+
def __get_non_querying_object(model, object_id_list) -> list[Document]:
81+
model = get_document(model)
82+
return [model(pk=each) for each in object_id_list]
83+
84+
@staticmethod
85+
async def __get_non_querying_object_async(model, object_id_list) -> list[Document]:
86+
return ListFieldResolver.__get_non_querying_object(model, object_id_list)
87+
88+
@staticmethod
89+
def __build_results(
90+
result: list[Document], to_resolve_object_ids: list[ObjectId]
91+
) -> list[Document]:
92+
result_object: dict[ObjectId, Document] = {}
93+
for items in result:
94+
for item in items:
95+
result_object[item.id] = item
96+
return [result_object[each] for each in to_resolve_object_ids]
97+
98+
# ======================= Main Logic =======================
99+
100+
@staticmethod
101+
def __reference_resolver_common(
102+
field, registry, executor: ExecutorEnum, root, *args, **kwargs
103+
) -> Optional[tuple[Union[list[Task], list[Document]], list[ObjectId]]]:
104+
to_resolve = getattr(root, field.name or field.db_name)
105+
if not to_resolve:
106+
return None
107+
108+
choice_to_resolve = dict()
109+
registry_string_map = (
110+
registry._registry_string_map
111+
if executor == ExecutorEnum.SYNC
112+
else registry._registry_async_string_map
113+
)
114+
querying_union_types = get_queried_union_types(
115+
info=args[0], valid_gql_types=registry_string_map.keys()
116+
)
117+
to_resolve_models = dict()
118+
for each, queried_fields in querying_union_types.items():
119+
to_resolve_models[registry_string_map[each]] = queried_fields
120+
to_resolve_object_ids: list[ObjectId] = list()
121+
for each in to_resolve:
122+
if isinstance(each, LazyReference):
123+
to_resolve_object_ids.append(each.pk)
124+
model = each.document_type._class_name
125+
if model not in choice_to_resolve:
126+
choice_to_resolve[model] = list()
127+
choice_to_resolve[model].append(each.pk)
128+
else:
129+
to_resolve_object_ids.append(each["_ref"].id)
130+
if each["_cls"] not in choice_to_resolve:
131+
choice_to_resolve[each["_cls"]] = list()
132+
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
133+
134+
if executor == ExecutorEnum.SYNC:
135+
pool = ThreadPoolExecutor(5)
136+
futures: list[Future] = list()
137+
for model, object_id_list in choice_to_resolve.items():
138+
if model in to_resolve_models:
139+
queried_fields = to_resolve_models[model]
140+
futures.append(
141+
pool.submit(
142+
ListFieldResolver.__get_reference_objects,
143+
*(registry, model, executor, object_id_list, queried_fields),
144+
)
145+
)
146+
else:
147+
futures.append(
148+
pool.submit(
149+
ListFieldResolver.__get_non_querying_object,
150+
*(model, object_id_list),
151+
)
152+
)
153+
result = [future.result() for future in as_completed(futures)]
154+
return result, to_resolve_object_ids
155+
else:
156+
loop = asyncio.get_event_loop()
157+
tasks: list[Task] = []
158+
for model, object_id_list in choice_to_resolve.items():
159+
if model in to_resolve_models:
160+
queried_fields = to_resolve_models[model]
161+
task = loop.create_task(
162+
ListFieldResolver.__get_reference_objects_async(
163+
registry, model, executor, object_id_list, queried_fields
164+
)
165+
)
166+
else:
167+
task = loop.create_task(
168+
ListFieldResolver.__get_non_querying_object_async(model, object_id_list)
169+
)
170+
tasks.append(task)
171+
return tasks, to_resolve_object_ids
172+
173+
@staticmethod
174+
def reference_resolver(field, registry, executor) -> Callable:
175+
def resolver(root, *args, **kwargs) -> Optional[list[Document]]:
176+
resolver_result = ListFieldResolver.__reference_resolver_common(
177+
field, registry, executor, root, *args, **kwargs
178+
)
179+
if not isinstance(resolver_result, tuple):
180+
return resolver_result
181+
result, to_resolve_object_ids = resolver_result
182+
return ListFieldResolver.__build_results(result, to_resolve_object_ids)
183+
184+
return resolver
185+
186+
@staticmethod
187+
def reference_resolver_async(field, registry, executor) -> Callable:
188+
async def resolver(root, *args, **kwargs) -> Optional[list[Document]]:
189+
resolver_result = ListFieldResolver.__reference_resolver_common(
190+
field, registry, executor, root, *args, **kwargs
191+
)
192+
if not isinstance(resolver_result, tuple):
193+
return resolver_result
194+
tasks, to_resolve_object_ids = resolver_result
195+
result: list[Document] = await asyncio.gather(*tasks)
196+
return ListFieldResolver.__build_results(result, to_resolve_object_ids)
197+
198+
return resolver
199+
200+
# ======================= Main Logic: END =======================

0 commit comments

Comments
 (0)