Skip to content

Commit 5d978e8

Browse files
committed
feat: Add MongoengineConnectionField.reference_args
1 parent f2ce892 commit 5d978e8

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

graphene_mongo/fields.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,34 @@ def model(self):
6060
@property
6161
def args(self):
6262
return to_arguments(
63-
self._base_args or OrderedDict(), self.default_filter_args
63+
self._base_args or OrderedDict(), dict(self.field_args.items() + self.reference_args.items())
6464
)
6565

6666
@args.setter
6767
def args(self, args):
6868
self._base_args = args
6969

7070
@property
71-
def default_filter_args(self):
71+
def field_args(self):
7272
def is_filterable(kv):
7373
return hasattr(kv[1], '_type') \
7474
and callable(getattr(kv[1]._type, '_of_type', None))
7575

7676
return reduce(
7777
lambda r, kv: r.update(
7878
{kv[0]: kv[1]._type._of_type()}) or r if is_filterable(kv) else r,
79-
self.fields.items(),
80-
{}
79+
self.fields.items(), {}
8180
)
8281

82+
@property
83+
def reference_args(self):
84+
def get_reference_field(r, kv):
85+
if callable(getattr(kv[1], 'get_type', None)):
86+
node = kv[1].get_type()._type._meta
87+
r.update({kv[0]: node.fields['id']._type.of_type()})
88+
return r
89+
return reduce(get_reference_field, self.fields.items(), {})
90+
8391
@property
8492
def filter_fields(self):
8593
return self._type._meta.filter_fields
@@ -121,7 +129,7 @@ def get_query(cls, model, info, **args):
121129
if first is not None:
122130
objs = objs[:first]
123131
if last is not None:
124-
# fix for https://github.com/graphql-python/graphene-mongo/issues/20
132+
# https://github.com/graphql-python/graphene-mongo/issues/20
125133
objs = objs[-(last+1):]
126134

127135
return objs

graphene_mongo/tests/test_fields.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
from ..fields import MongoengineConnectionField
4+
from .types import ArticleNode
5+
6+
7+
def test_field_args():
8+
field = MongoengineConnectionField(ArticleNode)
9+
10+
field_args = ['id', 'headline', 'pub_date']
11+
assert set(field.field_args.keys()) == set(field_args)
12+
13+
reference_args = ['editor', 'reporter']
14+
assert set(field.reference_args.keys()) == set(reference_args)
15+
16+
default_args = ['after', 'last', 'first', 'before']
17+
args = field_args + reference_args + default_args
18+
assert set(field.args) == set(args)

graphene_mongo/tests/test_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,20 @@
55
from ..utils import (
66
get_model_fields, is_valid_mongoengine_model
77
)
8-
from .models import Reporter, Child
8+
from .models import Article, Reporter, Child
9+
910

1011
def test_get_model_fields_no_duplication():
1112
reporter_fields = get_model_fields(Reporter)
1213
reporter_name_set = set(reporter_fields)
1314
assert len(reporter_fields) == len(reporter_name_set)
1415

1516

17+
def test_get_model_relation_fields():
18+
article_fields = get_model_fields(Article)
19+
assert all(field in set(article_fields) for field in ['editor', 'reporter'])
20+
21+
1622
def test_get_base_model_fields():
1723
child_fields = get_model_fields(Child)
1824
assert all(field in set(child_fields) for field in ['bar', 'baz'])

0 commit comments

Comments
 (0)