Skip to content

Commit 8c896ed

Browse files
authored
fix(segments): add an util to handle live query (#149)
1 parent 49103ae commit 8c896ed

File tree

4 files changed

+224
-8
lines changed

4 files changed

+224
-8
lines changed

django_forest/resources/utils/queryset/__init__.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from django_forest.utils.collection import Collection
21
from .filters import FiltersMixin
32
from .limit_fields import LimitFieldsMixin
43
from .pagination import PaginationMixin
54
from .scope import ScopeMixin
65
from .search import SearchMixin
6+
from .segment import SegmentMixin
77
from django_forest.resources.utils.decorators import DecoratorsMixin
88

99

10-
class QuerysetMixin(PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin):
11-
10+
class QuerysetMixin(
11+
PaginationMixin, FiltersMixin, SearchMixin, ScopeMixin, DecoratorsMixin, LimitFieldsMixin, SegmentMixin
12+
):
1213
def filter_queryset(self, queryset, Model, params, request):
1314
# Notice: first apply scope
1415
scope_filters = self.get_scope(request, Model)
@@ -34,11 +35,7 @@ def enhance_queryset(self, queryset, Model, params, request, apply_pagination=Tr
3435
queryset = queryset.order_by(params['sort'].replace('.', '__'))
3536

3637
# segment
37-
if 'segment' in params:
38-
collection = Collection._registry[Model._meta.db_table]
39-
segment = next((x for x in collection.segments if x['name'] == params['segment']), None)
40-
if segment is not None and 'where' in segment:
41-
queryset = queryset.filter(segment['where']())
38+
queryset = self.handle_segment(params, Model, queryset)
4239

4340
# limit fields
4441
queryset = self.handle_limit_fields(params, Model, queryset)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import re
2+
from django.db import connection
3+
4+
5+
class LiveQuerySegmentMixin:
6+
def handle_live_query_segment(self, live_query, Model, queryset):
7+
ids = self._get_live_query_ids(live_query)
8+
pk_field = Model._meta.pk.attname
9+
queryset = queryset.filter(**{f"{pk_field}__in": ids})
10+
return queryset
11+
12+
def _get_live_query_ids(self, live_query):
13+
self._validate_query(live_query)
14+
sql_query = "select id from (%s) as ids;" % live_query[0:live_query.find(";")]
15+
with connection.cursor() as cursor:
16+
cursor.execute(sql_query)
17+
res = cursor.fetchall()
18+
return [r[0] for r in res]
19+
20+
def _validate_query(self, query):
21+
if len(query.strip()) == 0:
22+
raise Exception("Live Query Segment: You cannot execute an empty SQL query.")
23+
24+
if ';' in query and query.find(';') < len(query.strip())-1:
25+
raise Exception("Live Query Segment: You cannot chain SQL queries.")
26+
27+
if not re.search(r'^SELECT\s.*FROM\s.*$', query, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL):
28+
raise Exception("Live Query Segment: Only SELECT queries are allowed.")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from .live_query_segment import LiveQuerySegmentMixin
2+
from django_forest.utils.collection import Collection
3+
4+
5+
class SegmentMixin(LiveQuerySegmentMixin):
6+
def handle_segment(self, params, Model, queryset):
7+
if 'segment' in params:
8+
collection = Collection._registry[Model._meta.db_table]
9+
segment = next((x for x in collection.segments if x['name'] == params['segment']), None)
10+
if segment is not None and 'where' in segment:
11+
queryset = queryset.filter(segment['where']())
12+
13+
# live query segment
14+
if "segmentQuery" in params:
15+
queryset = self.handle_live_query_segment(params['segmentQuery'], Model, queryset)
16+
17+
return queryset
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import copy
2+
import sys
3+
from datetime import datetime
4+
from unittest import mock
5+
6+
import pytest
7+
import pytz
8+
from django.test import TransactionTestCase
9+
from django.urls import reverse
10+
from freezegun import freeze_time
11+
12+
from django_forest.tests.fixtures.schema import test_schema
13+
from django_forest.tests.resources.views.list.test_list_scope import mocked_scope
14+
from django_forest.utils.schema import Schema
15+
from django_forest.utils.schema.json_api_schema import JsonApiSchema
16+
from django_forest.utils.date import get_timezone
17+
18+
19+
# reset forest config dir auto import
20+
from django_forest.utils.scope import ScopeManager
21+
22+
23+
@pytest.fixture()
24+
def reset_config_dir_import():
25+
for key in list(sys.modules.keys()):
26+
if key.startswith('django_forest.tests.forest'):
27+
del sys.modules[key]
28+
29+
30+
@pytest.mark.usefixtures('reset_config_dir_import')
31+
class ResourceListSmartSegmentViewTests(TransactionTestCase):
32+
fixtures = ['article.json', 'publication.json',
33+
'session.json',
34+
'question.json', 'choice.json',
35+
'place.json', 'restaurant.json',
36+
'student.json',
37+
'serial.json']
38+
39+
@pytest.fixture(autouse=True)
40+
def inject_fixtures(self, django_assert_num_queries):
41+
self._django_assert_num_queries = django_assert_num_queries
42+
43+
def setUp(self):
44+
Schema.schema = copy.deepcopy(test_schema)
45+
Schema.add_smart_features()
46+
Schema.handle_json_api_schema()
47+
self.url = reverse('django_forest:resources:list', kwargs={'resource': 'tests_question'})
48+
self.client = self.client_class(
49+
HTTP_AUTHORIZATION='Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjUiLCJlbWFpbCI6Imd1aWxsYXVtZWNAZm9yZXN0YWRtaW4uY29tIiwiZmlyc3RfbmFtZSI6Ikd1aWxsYXVtZSIsImxhc3RfbmFtZSI6IkNpc2NvIiwidGVhbSI6Ik9wZXJhdGlvbnMiLCJyZW5kZXJpbmdfaWQiOjEsImV4cCI6MTYyNTY3OTYyNi44ODYwMTh9.mHjA05yvMr99gFMuFv0SnPDCeOd2ZyMSN868V7lsjnw')
50+
51+
def tearDown(self):
52+
# reset _registry after each test
53+
JsonApiSchema._registry = {}
54+
ScopeManager.cache = {}
55+
56+
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
57+
@freeze_time(
58+
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
59+
)
60+
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
61+
def test_get(self, mocked_scope_has_expired, mocked_decode):
62+
ScopeManager.cache = {
63+
'1': {
64+
'scopes': mocked_scope,
65+
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
66+
}
67+
}
68+
response = self.client.get(self.url, {
69+
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
70+
'fields[topic]': 'name',
71+
'segmentQuery': 'select * from tests_question where id=1;',
72+
'page[number]': '1',
73+
'page[size]': '15',
74+
'timezone': 'Europe/Paris',
75+
})
76+
data = response.json()
77+
self.assertEqual(response.status_code, 200)
78+
self.assertEqual(data, {
79+
'data': [
80+
{
81+
'type': 'tests_question',
82+
'attributes': {
83+
'pub_date': '2021-06-02T13:52:53.528000+00:00',
84+
'question_text': 'what is your favorite color?',
85+
'foo': 'what is your favorite color?+foo',
86+
'bar': 'what is your favorite color?+bar'
87+
},
88+
'id': 1,
89+
'links': {
90+
'self': '/forest/tests_question/1'
91+
},
92+
'relationships': {
93+
'topic': {
94+
'data': None,
95+
'links': {
96+
'related': '/forest/tests_question/1/relationships/topic'
97+
}
98+
}
99+
},
100+
},
101+
]
102+
})
103+
104+
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
105+
@freeze_time(
106+
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
107+
)
108+
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
109+
def test_get_error_when_multiple_request(self, mocked_scope_has_expired, mocked_decode):
110+
ScopeManager.cache = {
111+
'1': {
112+
'scopes': mocked_scope,
113+
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
114+
}
115+
}
116+
response = self.client.get(self.url, {
117+
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
118+
'fields[topic]': 'name',
119+
'segmentQuery': 'select * from tests_question where id=1;select * from user_users',
120+
'page[number]': '1',
121+
'page[size]': '15',
122+
'timezone': 'Europe/Paris',
123+
})
124+
data = response.json()
125+
self.assertEqual(response.status_code, 400)
126+
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot chain SQL queries.'}]})
127+
128+
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
129+
@freeze_time(
130+
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
131+
)
132+
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
133+
def test_get_error_when_sql_is_not_select(self, mocked_scope_has_expired, mocked_decode):
134+
ScopeManager.cache = {
135+
'1': {
136+
'scopes': mocked_scope,
137+
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
138+
}
139+
}
140+
response = self.client.get(self.url, {
141+
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
142+
'fields[topic]': 'name',
143+
'segmentQuery': 'insert into tests_question(id) values(999)',
144+
'page[number]': '1',
145+
'page[size]': '15',
146+
'timezone': 'Europe/Paris',
147+
})
148+
data = response.json()
149+
self.assertEqual(response.status_code, 400)
150+
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: Only SELECT queries are allowed.'}]})
151+
152+
@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1})
153+
@freeze_time(
154+
lambda: datetime(2021, 7, 8, 9, 20, 23, 582772, tzinfo=get_timezone('UTC'))
155+
)
156+
@mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False)
157+
def test_get_error_when_sql_is_empty(self, mocked_scope_has_expired, mocked_decode):
158+
ScopeManager.cache = {
159+
'1': {
160+
'scopes': mocked_scope,
161+
'fetched_at': datetime(2021, 7, 8, 9, 20, 22, 582772, tzinfo=pytz.UTC)
162+
}
163+
}
164+
response = self.client.get(self.url, {
165+
'fields[tests_question]': 'id,topic,question_text,pub_date,foo,bar',
166+
'fields[topic]': 'name',
167+
'segmentQuery': ' \n',
168+
'page[number]': '1',
169+
'page[size]': '15',
170+
'timezone': 'Europe/Paris',
171+
})
172+
data = response.json()
173+
self.assertEqual(response.status_code, 400)
174+
self.assertEqual(data, {"errors": [{'detail': 'Live Query Segment: You cannot execute an empty SQL query.'}]})

0 commit comments

Comments
 (0)