Skip to content

Commit b7297ba

Browse files
committed
Add nested field search query support
1 parent 2bfec6e commit b7297ba

File tree

4 files changed

+88
-6
lines changed

4 files changed

+88
-6
lines changed

examples/simple/search_indexes/documents/city.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class CityDocument(DocType):
5353
# ********************************************************************
5454

5555
# City object
56-
country = fields.ObjectField(
56+
country = fields.NestedField(
5757
properties={
5858
'name': StringField(
5959
analyzer=html_strip,

examples/simple/search_indexes/viewsets/city.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ class CityDocumentViewSet(BaseDocumentViewSet):
4646
search_fields = (
4747
'name',
4848
'info',
49-
'country.name',
5049
)
50+
51+
search_nested_fields = {
52+
'country': ['name'],
53+
}
54+
5155
# Define filtering fields
5256
filter_fields = {
5357
'id': None,

src/django_elasticsearch_dsl_drf/filter_backends/search.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class SearchFilterBackend(BaseFilterBackend, FilterBackendMixin):
4343
>>> 'title',
4444
>>> 'content',
4545
>>> )
46+
>>> search_nested_fields = {
47+
>>> 'state': ['name'],
48+
>>> 'documents.author': ['title', 'description'],
49+
>>> }
4650
"""
4751

4852
search_param = api_settings.SEARCH_PARAM
@@ -58,6 +62,41 @@ def get_search_query_params(self, request):
5862
query_params = request.query_params.copy()
5963
return query_params.getlist(self.search_param, [])
6064

65+
def construct_nested_search(self, request, view):
66+
"""Construct nested search.
67+
68+
:param request: Django REST framework request.
69+
:param queryset: Base queryset.
70+
:param view: View.
71+
:type request: rest_framework.request.Request
72+
:type queryset: elasticsearch_dsl.search.Search
73+
:type view: rest_framework.viewsets.ReadOnlyModelViewSet
74+
:return: Updated queryset.
75+
:rtype: elasticsearch_dsl.search.Search
76+
"""
77+
if not hasattr(view, 'search_nested_fields'):
78+
return []
79+
80+
query_params = self.get_search_query_params(request)
81+
__queries = []
82+
for search_term in query_params:
83+
for path, fields in view.search_nested_fields.items():
84+
queries = []
85+
for field in fields:
86+
field_key = "{}.{}".format(path, field)
87+
queries.append(
88+
Q("match", **{field_key: search_term})
89+
)
90+
91+
__queries.append(
92+
Q("nested",
93+
path=path,
94+
query=six.moves.reduce(operator.or_, queries)
95+
)
96+
)
97+
98+
return __queries
99+
61100
def construct_search(self, request, view):
62101
"""Construct search.
63102
@@ -99,10 +138,9 @@ def filter_queryset(self, request, queryset, view):
99138
:return: Updated queryset.
100139
:rtype: elasticsearch_dsl.search.Search
101140
"""
102-
__queries = self.construct_search(request, view)
141+
__queries = self.construct_search(request, view) +\
142+
self.construct_nested_search(request, view)
103143

104144
if __queries:
105-
queryset = queryset.query(
106-
six.moves.reduce(operator.or_, __queries)
107-
)
145+
queryset = queryset.query('bool', should=__queries)
108146
return queryset

src/django_elasticsearch_dsl_drf/tests/test_search.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ def setUp(cls):
5656
)
5757

5858
cls.all_count = cls.special_count + cls.lorem_count
59+
60+
cls.cities_count = 20
61+
cls.cities = factories.CityFactory.create_batch(
62+
cls.cities_count)
63+
cls.switzerland = factories.CountryFactory.create(name='Switzerland')
64+
cls.switz_cities_count = 10
65+
cls.switz_cities = factories.CityFactory.create_batch(
66+
cls.switz_cities_count,
67+
country=cls.switzerland)
68+
cls.all_cities_cound = cls.cities_count + cls.switz_cities_count
69+
5970
call_command('search_index', '--rebuild', '-f')
6071

6172
def _search_by_field(self, field_name, search_term):
@@ -81,13 +92,42 @@ def _search_by_field(self, field_name, search_term):
8192
self.special_count
8293
)
8394

95+
def _search_by_nested_field(self, search_term):
96+
"""Search by field."""
97+
self.authenticate()
98+
99+
url = reverse('citydocument-list', kwargs={})
100+
data = {}
101+
102+
# Should contain 20 results
103+
response = self.client.get(url, data)
104+
self.assertEqual(response.status_code, status.HTTP_200_OK)
105+
self.assertEqual(len(response.data['results']), self.all_cities_cound)
106+
107+
# Should contain only 10 results
108+
filtered_response = self.client.get(
109+
url + '?search={}'.format(search_term),
110+
data
111+
)
112+
self.assertEqual(filtered_response.status_code, status.HTTP_200_OK)
113+
self.assertEqual(
114+
len(filtered_response.data['results']),
115+
self.switz_cities_count
116+
)
117+
84118
def test_search_by_field(self):
85119
"""Search by field."""
86120
return self._search_by_field(
87121
'summary',
88122
'photography',
89123
)
90124

125+
def test_search_by_nested_field(self):
126+
"""Search by field."""
127+
return self._search_by_nested_field(
128+
'Switzerland',
129+
)
130+
91131

92132
if __name__ == '__main__':
93133
unittest.main()

0 commit comments

Comments
 (0)