11
11
from graphene .types .argument import to_arguments
12
12
13
13
14
+ from .utils import get_model_reference_fields
15
+
16
+
14
17
# noqa
15
18
class MongoengineListField (Field ):
16
19
@@ -60,26 +63,35 @@ def model(self):
60
63
@property
61
64
def args (self ):
62
65
return to_arguments (
63
- self ._base_args or OrderedDict (), self .default_filter_args
66
+ self ._base_args or OrderedDict (),
67
+ dict (self .field_args , ** self .reference_args )
64
68
)
65
69
66
70
@args .setter
67
71
def args (self , args ):
68
72
self ._base_args = args
69
73
70
74
@property
71
- def default_filter_args (self ):
75
+ def field_args (self ):
72
76
def is_filterable (kv ):
73
77
return hasattr (kv [1 ], '_type' ) \
74
78
and callable (getattr (kv [1 ]._type , '_of_type' , None ))
75
79
76
80
return reduce (
77
81
lambda r , kv : r .update (
78
82
{kv [0 ]: kv [1 ]._type ._of_type ()}) or r if is_filterable (kv ) else r ,
79
- self .fields .items (),
80
- {}
83
+ self .fields .items (), {}
81
84
)
82
85
86
+ @property
87
+ def reference_args (self ):
88
+ def get_reference_field (r , kv ):
89
+ if callable (getattr (kv [1 ], 'get_type' , None )):
90
+ node = kv [1 ].get_type ()._type ._meta
91
+ r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
92
+ return r
93
+ return reduce (get_reference_field , self .fields .items (), {})
94
+
83
95
@property
84
96
def filter_fields (self ):
85
97
return self ._type ._meta .filter_fields
@@ -95,8 +107,17 @@ def get_query(cls, model, info, **args):
95
107
return []
96
108
97
109
objs = model .objects ()
98
-
99
110
if args :
111
+ reference_fields = get_model_reference_fields (model )
112
+ reference_args = {}
113
+ for arg_name , arg in args .copy ().items ():
114
+ if arg_name in reference_fields :
115
+ reference_model = model ._fields [arg_name ]
116
+ pk = from_global_id (args .pop (arg_name ))[- 1 ]
117
+ reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
118
+ reference_args [arg_name ] = reference_obj
119
+
120
+ args .update (reference_args )
100
121
first = args .pop ('first' , None )
101
122
last = args .pop ('last' , None )
102
123
id = args .pop ('id' , None )
@@ -121,7 +142,7 @@ def get_query(cls, model, info, **args):
121
142
if first is not None :
122
143
objs = objs [:first ]
123
144
if last is not None :
124
- # fix for https://github.com/graphql-python/graphene-mongo/issues/20
145
+ # https://github.com/graphql-python/graphene-mongo/issues/20
125
146
objs = objs [- (last + 1 ):]
126
147
127
148
return objs
0 commit comments