Skip to content

Commit 99512c5

Browse files
tcleonardThomas Leonard
andauthored
fix: in and range filters on DjangoFilterConnectionField (#1070)
Co-authored-by: Thomas Leonard <[email protected]>
1 parent 7b35695 commit 99512c5

File tree

3 files changed

+202
-11
lines changed

3 files changed

+202
-11
lines changed

graphene_django/filter/fields.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
self._fields = fields
2222
self._provided_filterset_class = filterset_class
2323
self._filterset_class = None
24+
self._filtering_args = None
2425
self._extra_filter_meta = extra_filter_meta
2526
self._base_args = None
2627
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
@@ -50,7 +51,11 @@ def filterset_class(self):
5051

5152
@property
5253
def filtering_args(self):
53-
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
54+
if not self._filtering_args:
55+
self._filtering_args = get_filtering_args_from_filterset(
56+
self.filterset_class, self.node_type
57+
)
58+
return self._filtering_args
5459

5560
@classmethod
5661
def resolve_queryset(
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import pytest
2+
3+
from graphene import ObjectType, Schema
4+
from graphene.relay import Node
5+
from graphene_django import DjangoObjectType
6+
from graphene_django.tests.models import Pet
7+
from graphene_django.utils import DJANGO_FILTER_INSTALLED
8+
9+
pytestmark = []
10+
11+
if DJANGO_FILTER_INSTALLED:
12+
from graphene_django.filter import DjangoFilterConnectionField
13+
else:
14+
pytestmark.append(
15+
pytest.mark.skipif(
16+
True, reason="django_filters not installed or not compatible"
17+
)
18+
)
19+
20+
21+
class PetNode(DjangoObjectType):
22+
class Meta:
23+
model = Pet
24+
interfaces = (Node,)
25+
filter_fields = {
26+
"name": ["exact", "in"],
27+
"age": ["exact", "in", "range"],
28+
}
29+
30+
31+
class Query(ObjectType):
32+
pets = DjangoFilterConnectionField(PetNode)
33+
34+
35+
def test_string_in_filter():
36+
"""
37+
Test in filter on a string field.
38+
"""
39+
Pet.objects.create(name="Brutus", age=12)
40+
Pet.objects.create(name="Mimi", age=3)
41+
Pet.objects.create(name="Jojo, the rabbit", age=3)
42+
43+
schema = Schema(query=Query)
44+
45+
query = """
46+
query {
47+
pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
48+
edges {
49+
node {
50+
name
51+
}
52+
}
53+
}
54+
}
55+
"""
56+
result = schema.execute(query)
57+
assert not result.errors
58+
assert result.data["pets"]["edges"] == [
59+
{"node": {"name": "Brutus"}},
60+
{"node": {"name": "Jojo, the rabbit"}},
61+
]
62+
63+
64+
def test_int_in_filter():
65+
"""
66+
Test in filter on an integer field.
67+
"""
68+
Pet.objects.create(name="Brutus", age=12)
69+
Pet.objects.create(name="Mimi", age=3)
70+
Pet.objects.create(name="Jojo, the rabbit", age=3)
71+
72+
schema = Schema(query=Query)
73+
74+
query = """
75+
query {
76+
pets (age_In: [3]) {
77+
edges {
78+
node {
79+
name
80+
}
81+
}
82+
}
83+
}
84+
"""
85+
result = schema.execute(query)
86+
assert not result.errors
87+
assert result.data["pets"]["edges"] == [
88+
{"node": {"name": "Mimi"}},
89+
{"node": {"name": "Jojo, the rabbit"}},
90+
]
91+
92+
query = """
93+
query {
94+
pets (age_In: [3, 12]) {
95+
edges {
96+
node {
97+
name
98+
}
99+
}
100+
}
101+
}
102+
"""
103+
result = schema.execute(query)
104+
assert not result.errors
105+
assert result.data["pets"]["edges"] == [
106+
{"node": {"name": "Brutus"}},
107+
{"node": {"name": "Mimi"}},
108+
{"node": {"name": "Jojo, the rabbit"}},
109+
]
110+
111+
112+
def test_int_range_filter():
113+
"""
114+
Test in filter on an integer field.
115+
"""
116+
Pet.objects.create(name="Brutus", age=12)
117+
Pet.objects.create(name="Mimi", age=8)
118+
Pet.objects.create(name="Jojo, the rabbit", age=3)
119+
Pet.objects.create(name="Picotin", age=5)
120+
121+
schema = Schema(query=Query)
122+
123+
query = """
124+
query {
125+
pets (age_Range: [4, 9]) {
126+
edges {
127+
node {
128+
name
129+
}
130+
}
131+
}
132+
}
133+
"""
134+
result = schema.execute(query)
135+
assert not result.errors
136+
assert result.data["pets"]["edges"] == [
137+
{"node": {"name": "Mimi"}},
138+
{"node": {"name": "Picotin"}},
139+
]

graphene_django/filter/utils.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import six
22

3+
from graphene import List
4+
35
from django_filters.utils import get_model_field
6+
from django_filters.filters import Filter, BaseCSVFilter
7+
48
from .filterset import custom_filterset_factory, setup_filterset
59

610

@@ -17,31 +21,74 @@ def get_filtering_args_from_filterset(filterset_class, type):
1721
form_field = None
1822

1923
if name in filterset_class.declared_filters:
24+
# Get the filter field from the explicitly declared filter
2025
form_field = filter_field.field
26+
field = convert_form_field(form_field)
2127
else:
28+
# Get the filter field with no explicit type declaration
2229
model_field = get_model_field(model, filter_field.field_name)
2330
filter_type = filter_field.lookup_expr
2431
if filter_type != "isnull" and hasattr(model_field, "formfield"):
2532
form_field = model_field.formfield(
2633
required=filter_field.extra.get("required", False)
2734
)
2835

29-
# Fallback to field defined on filter if we can't get it from the
30-
# model field
31-
if not form_field:
32-
form_field = filter_field.field
36+
# Fallback to field defined on filter if we can't get it from the
37+
# model field
38+
if not form_field:
39+
form_field = filter_field.field
40+
41+
field = convert_form_field(form_field)
3342

34-
field_type = convert_form_field(form_field).Argument()
43+
if filter_type in ["in", "range"]:
44+
# Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field.
45+
# See comments in `replace_csv_filters` method for more details.
46+
field = List(field.get_type())
47+
48+
field_type = field.Argument()
3549
field_type.description = filter_field.label
3650
args[name] = field_type
3751

3852
return args
3953

4054

4155
def get_filterset_class(filterset_class, **meta):
42-
"""Get the class to be used as the FilterSet"""
56+
"""
57+
Get the class to be used as the FilterSet.
58+
"""
4359
if filterset_class:
44-
# If were given a FilterSet class, then set it up and
45-
# return it
46-
return setup_filterset(filterset_class)
47-
return custom_filterset_factory(**meta)
60+
# If were given a FilterSet class, then set it up.
61+
graphene_filterset_class = setup_filterset(filterset_class)
62+
else:
63+
# Otherwise create one.
64+
graphene_filterset_class = custom_filterset_factory(**meta)
65+
66+
replace_csv_filters(graphene_filterset_class)
67+
return graphene_filterset_class
68+
69+
70+
def replace_csv_filters(filterset_class):
71+
"""
72+
Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
73+
but regular Filter objects that simply use the input value as filter argument on the queryset.
74+
75+
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
76+
can actually have a list as input and have a proper type verification of each value in the list.
77+
78+
See issue https://github.com/graphql-python/graphene-django/issues/1068.
79+
"""
80+
for name, filter_field in six.iteritems(filterset_class.base_filters):
81+
filter_type = filter_field.lookup_expr
82+
if (
83+
filter_type in ["in", "range"]
84+
and name not in filterset_class.declared_filters
85+
):
86+
assert isinstance(filter_field, BaseCSVFilter)
87+
filterset_class.base_filters[name] = Filter(
88+
field_name=filter_field.field_name,
89+
lookup_expr=filter_field.lookup_expr,
90+
label=filter_field.label,
91+
method=filter_field.method,
92+
exclude=filter_field.exclude,
93+
**filter_field.extra
94+
)

0 commit comments

Comments
 (0)