Skip to content

Commit 49d48bd

Browse files
committed
Add geo_distance_ordering
1 parent b0f4d3e commit 49d48bd

File tree

3 files changed

+163
-28
lines changed

3 files changed

+163
-28
lines changed

src/django_elasticsearch_dsl_drf/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@
9797
# Search query param
9898
SEARCH_QUERY_PARAM = 'q'
9999

100+
# Geo distance ordering param
101+
GEO_DISTANCE_ORDERING_PARAM = 'geo_distance_ordering'
102+
100103
# ****************************************************************************
101104
# ************************ Native lookup filters/queries *********************
102105
# ****************************************************************************

src/django_elasticsearch_dsl_drf/filter_backends/ordering/geo_spatial.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
"""
44

55
from rest_framework.filters import BaseFilterBackend
6-
from rest_framework.settings import api_settings
6+
7+
8+
from ..mixins import FilterBackendMixin
9+
from ...constants import (
10+
GEO_DISTANCE_ORDERING_PARAM,
11+
)
712

813
__title__ = 'django_elasticsearch_dsl_drf.filter_backends.ordering.common'
914
__author__ = 'Artur Barseghyan <[email protected]>'
@@ -12,7 +17,7 @@
1217
__all__ = ('GeoSpatialOrderingFilterBackend',)
1318

1419

15-
class GeoSpatialOrderingFilterBackend(BaseFilterBackend):
20+
class GeoSpatialOrderingFilterBackend(BaseFilterBackend, FilterBackendMixin):
1621
"""Geo-spatial ordering filter backend for Elasticsearch.
1722
1823
Example:
@@ -40,27 +45,46 @@ class GeoSpatialOrderingFilterBackend(BaseFilterBackend):
4045
>>> }
4146
"""
4247

43-
ordering_param = api_settings.ORDERING_PARAM
44-
45-
# TODO: Either use or remove.
46-
# @classmethod
47-
# def prepare_ordering_fields(cls, view):
48-
# """Prepare ordering fields.
49-
#
50-
# :param view: View.
51-
# :type view: rest_framework.viewsets.ReadOnlyModelViewSet
52-
# :return: Ordering options.
53-
# :rtype: dict
54-
# """
55-
# ordering_fields = view.ordering_fields
56-
# for field, options in ordering_fields.items():
57-
# if options is None or isinstance(options, string_types):
58-
# ordering_fields[field] = {
59-
# 'field': options or field
60-
# }
61-
# elif 'field' not in ordering_fields[field]:
62-
# ordering_fields[field]['field'] = field
63-
# return ordering_fields
48+
ordering_param = GEO_DISTANCE_ORDERING_PARAM
49+
50+
@classmethod
51+
def get_geo_distance_params(cls, value, field):
52+
"""Get params for `geo_distance` ordering.
53+
54+
Example:
55+
56+
/api/articles/?geo_spatial_ordering=-location|45.3214|-34.3421|km|planes
57+
58+
:param value:
59+
:param field:
60+
:type value: str
61+
:type field:
62+
:return: Params to be used in `geo_distance` query.
63+
:rtype: dict
64+
"""
65+
__values = cls.split_lookup_value(value, maxsplit=3)
66+
__len_values = len(__values)
67+
68+
if __len_values < 2:
69+
return {}
70+
71+
params = {
72+
field: {
73+
'lat': __values[0],
74+
'lon': __values[1],
75+
}
76+
}
77+
78+
if __len_values > 2:
79+
params['unit'] = __values[2]
80+
else:
81+
params['unit'] = 'm'
82+
if __len_values > 3:
83+
params['distance_type'] = __values[3]
84+
else:
85+
params['distance_type'] = 'sloppy_arc'
86+
87+
return params
6488

6589
def get_ordering_query_params(self, request, view):
6690
"""Get ordering query params.
@@ -78,13 +102,16 @@ def get_ordering_query_params(self, request, view):
78102
__ordering_params = []
79103
# Remove invalid ordering query params
80104
for query_param in ordering_query_params:
81-
__key = query_param.lstrip('-')
82-
__direction = '-' if query_param.startswith('-') else ''
105+
__key, __value = FilterBackendMixin.split_lookup_value(
106+
query_param.lstrip('-'),
107+
maxsplit=1,
108+
)
109+
__direction = 'desc' if query_param.startswith('-') else 'asc'
83110
if __key in view.geo_spatial_ordering_fields:
84111
__field_name = view.ordering_fields[__key] or __key
85-
__ordering_params.append(
86-
'{}{}'.format(__direction, __field_name)
87-
)
112+
__params = self.get_geo_distance_params(__value, __field_name)
113+
__params['order'] = __direction
114+
__ordering_params.append(__params)
88115

89116
return __ordering_params
90117

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Test geo-spatial filtering backend.
3+
"""
4+
5+
from __future__ import absolute_import
6+
7+
import unittest
8+
9+
from django.core.management import call_command
10+
11+
from nine.versions import DJANGO_GTE_1_10
12+
13+
import pytest
14+
15+
from rest_framework import status
16+
17+
import factories
18+
19+
from .base import BaseRestFrameworkTestCase
20+
21+
if DJANGO_GTE_1_10:
22+
from django.urls import reverse
23+
else:
24+
from django.core.urlresolvers import reverse
25+
26+
__title__ = 'django_elasticsearch_dsl_drf.tests.test_filtering'
27+
__author__ = 'Artur Barseghyan <[email protected]>'
28+
__copyright__ = '2017 Artur Barseghyan'
29+
__license__ = 'GPL 2.0/LGPL 2.1'
30+
__all__ = (
31+
'TestFilteringGeoSpatial',
32+
)
33+
34+
35+
@pytest.mark.django_db
36+
class TestOrderingGeoSpatial(BaseRestFrameworkTestCase):
37+
"""Test filtering geo-spatial."""
38+
39+
pytestmark = pytest.mark.django_db
40+
41+
@classmethod
42+
def setUpClass(cls):
43+
"""Set up."""
44+
cls.geo_origin = factories.PublisherFactory.create(
45+
**{
46+
'latitude': 48.8549,
47+
'longitude': 2.3000,
48+
}
49+
)
50+
51+
cls.geo_in_count = 5
52+
cls.unit = 'km'
53+
cls.algo = 'plane'
54+
cls.geo_in = []
55+
for index in range(cls.geo_in_count):
56+
__publisher = factories.PublisherFactory.create(
57+
**{
58+
'latitude': 48.8570 + index,
59+
'longitude': 2.3005,
60+
}
61+
)
62+
cls.geo_in.append(__publisher)
63+
64+
cls.base_publisher_url = reverse('publisherdocument-list', kwargs={})
65+
call_command('search_index', '--rebuild', '-f')
66+
67+
@pytest.mark.webtest
68+
def test_field_filter_geo_distance(self):
69+
"""Field filter term.
70+
71+
Example:
72+
73+
http://localhost:8000
74+
/api/publisher/?geo_distance_ordering=location|48.8549|2.3000|km|plane
75+
"""
76+
self.authenticate()
77+
78+
__params = 'location|{}|{}|{}|{}'.format(
79+
self.geo_origin.latitude,
80+
self.geo_origin.longitude,
81+
self.unit,
82+
self.algo
83+
)
84+
85+
url = self.base_publisher_url[:] + '?{}={}'.format(
86+
'geo_distance_ordering',
87+
__params
88+
)
89+
90+
data = {}
91+
response = self.client.get(url, data)
92+
self.assertEqual(response.status_code, status.HTTP_200_OK)
93+
# Should contain only 6 results
94+
self.assertEqual(len(response.data['results']), self.geo_in_count + 1)
95+
item_count = len(response.data['results'])
96+
for counter, item in enumerate(response.data['results']):
97+
if (counter > 1) and (counter < item_count + 1):
98+
self.assertLess(
99+
response.data['results'][counter-1]['location']['lat'],
100+
response.data['results'][counter]['location']['lat']
101+
)
102+
103+
104+
if __name__ == '__main__':
105+
unittest.main()

0 commit comments

Comments
 (0)