7
7
from mongoengine .base import get_document
8
8
from . import advanced_types
9
9
from .utils import import_single_dispatch , get_field_description , get_query_fields
10
+ from concurrent .futures import ThreadPoolExecutor , wait , as_completed
10
11
11
12
singledispatch = import_single_dispatch ()
12
13
@@ -104,6 +105,46 @@ def convert_file_to_field(field, registry=None):
104
105
def convert_field_to_list (field , registry = None ):
105
106
base_type = convert_mongoengine_field (field .field , registry = registry )
106
107
if isinstance (base_type , graphene .Field ):
108
+ if isinstance (field .field , mongoengine .GenericReferenceField ):
109
+ def get_reference_objects (* args , ** kwargs ):
110
+ if args [0 ][1 ]:
111
+ document = get_document (args [0 ][0 ])
112
+ document_field = mongoengine .ReferenceField (document )
113
+ document_field = convert_mongoengine_field (document_field , registry )
114
+ document_field_type = document_field .get_type ().type ._meta .name
115
+ only_fields = [to_snake_case (i ) for i in get_query_fields (args [0 ][3 ])[document_field_type ].keys ()]
116
+ return document .objects ().no_dereference ().only (* only_fields ).filter (pk__in = args [0 ][1 ])
117
+ else :
118
+ return []
119
+
120
+ def reference_resolver (root , * args , ** kwargs ):
121
+ choice_to_resolve = dict ()
122
+ to_resolve = getattr (root , field .name or field .db_name )
123
+ for each in to_resolve :
124
+ if each ['_cls' ] not in choice_to_resolve :
125
+ choice_to_resolve [each ['_cls' ]] = list ()
126
+ choice_to_resolve [each ['_cls' ]].append (each ["_ref" ].id )
127
+
128
+ pool = ThreadPoolExecutor (5 )
129
+ futures = list ()
130
+ for model , object_id_list in choice_to_resolve .items ():
131
+ futures .append (pool .submit (get_reference_objects , (model , object_id_list , registry , * args )))
132
+ result = list ()
133
+ for x in as_completed (futures ):
134
+ result += x .result ()
135
+ to_resolve_object_ids = [each ["_ref" ].id for each in to_resolve ]
136
+ result_to_resolve_object_ids = [each .id for each in result ]
137
+ ordered_result = list ()
138
+ for each in to_resolve_object_ids :
139
+ ordered_result .append (result [result_to_resolve_object_ids .index (each )])
140
+ return ordered_result
141
+
142
+ return graphene .List (
143
+ base_type ._type ,
144
+ description = get_field_description (field , registry ),
145
+ required = field .required ,
146
+ resolver = reference_resolver
147
+ )
107
148
return graphene .List (
108
149
base_type ._type ,
109
150
description = get_field_description (field , registry ),
@@ -128,7 +169,7 @@ def convert_field_to_list(field, registry=None):
128
169
return graphene .List (
129
170
base_type ,
130
171
description = get_field_description (field , registry ),
131
- required = field .required
172
+ required = field .required ,
132
173
)
133
174
134
175
@@ -172,7 +213,8 @@ def reference_resolver(root, *args, **kwargs):
172
213
return document .objects ().no_dereference ().only (* only_fields ).get (pk = dereferenced ["_ref" ].id )
173
214
174
215
if isinstance (field , mongoengine .GenericReferenceField ):
175
- return graphene .Field (_union , resolver = reference_resolver )
216
+ return graphene .Field (_union , resolver = reference_resolver ,
217
+ description = get_field_description (field , registry ))
176
218
177
219
return graphene .Field (_union )
178
220
0 commit comments