diff --git a/django_forest/resources/associations/views/csv.py b/django_forest/resources/associations/views/csv.py index f1f252b..4e7e10c 100644 --- a/django_forest/resources/associations/views/csv.py +++ b/django_forest/resources/associations/views/csv.py @@ -1,5 +1,6 @@ import logging import csv +import time from django_forest.resources.associations.utils import AssociationView from django_forest.resources.utils.csv import CsvMixin @@ -25,8 +26,11 @@ def get(self, request, pk, association_resource): queryset = getattr(self.Model.objects.get(pk=pk), association_resource).all() params = request.GET.dict() + # enhance queryset queryset = self.enhance_queryset(queryset, RelatedModel, params, request, apply_pagination=False) + for _ in queryset: # force SQL request execution + break # handle smart fields self.handle_smart_fields(queryset, RelatedModel._meta.db_table, parse_qs(params), many=True) diff --git a/django_forest/resources/utils/queryset/__init__.py b/django_forest/resources/utils/queryset/__init__.py index 3ca161c..79dd59d 100644 --- a/django_forest/resources/utils/queryset/__init__.py +++ b/django_forest/resources/utils/queryset/__init__.py @@ -1,3 +1,5 @@ +import re + from .filters import FiltersMixin from .limit_fields import LimitFieldsMixin from .pagination import PaginationMixin @@ -5,6 +7,7 @@ from .search import SearchMixin from .segment import SegmentMixin from django_forest.resources.utils.decorators import DecoratorsMixin +from django_forest.utils.schema import Schema class QuerysetMixin( @@ -26,7 +29,29 @@ def filter_queryset(self, queryset, Model, params, request): queryset = queryset.filter(method(params, Model)) return queryset + def join_relations(self, queryset, Model, params, request): + select_related = set() + + collection = Schema.get_collection(Model._meta.db_table) + relations = [ + field['field'] + for field in collection["fields"] + if field["relationship"] is not None and field["relationship"] in ["BelongsTo", "HasOne"] + ] + + # projection + for key, value in params.items(): + if re.search(r"fields\[[^\]]+\]", key): + fields_for = key.split("fields[")[1][:-1] + if fields_for in relations: + select_related.add(fields_for) + + return queryset.select_related(*select_related) + def enhance_queryset(self, queryset, Model, params, request, apply_pagination=True): + # perform inner join + queryset = self.join_relations(queryset, Model, params, request) + # scopes + filter + search queryset = self.filter_queryset(queryset, Model, params, request) diff --git a/django_forest/tests/resources/views/list/test_list_filters_date.py b/django_forest/tests/resources/views/list/test_list_filters_date.py index 2a9717c..9c776a6 100644 --- a/django_forest/tests/resources/views/list/test_list_filters_date.py +++ b/django_forest/tests/resources/views/list/test_list_filters_date.py @@ -15,7 +15,7 @@ @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) -@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) +@mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) class ResourceListFilterDateViewTests(TransactionTestCase): fixtures = ['question.json'] @@ -90,7 +90,8 @@ def test_future(self, *args, **kwargs): 'filters': '{"field":"pub_date","operator":"future","value":null}', 'timezone': 'Europe/Paris', 'page[number]': '1', - 'page[size]': '15' + 'page[size]': '15', + 'sort': 'id' }) data = response.json() self.assertEqual(response.status_code, 200) @@ -241,7 +242,8 @@ def test_after_x_hours_ago(self, *args, **kwargs): 'filters': '{"field":"pub_date","operator":"after_x_hours_ago","value":1}', 'timezone': 'Europe/Paris', 'page[number]': '1', - 'page[size]': '15' + 'page[size]': '15', + 'sort': 'id' }) data = response.json() self.assertEqual(response.status_code, 200) diff --git a/django_forest/tests/resources/views/list/test_list_sort.py b/django_forest/tests/resources/views/list/test_list_sort.py index cf3b949..2d4911b 100644 --- a/django_forest/tests/resources/views/list/test_list_sort.py +++ b/django_forest/tests/resources/views/list/test_list_sort.py @@ -75,11 +75,12 @@ def test_get_sort(self, mocked_datetime, mocked_decode): data = response.json() self.assertEqual(response.status_code, 200) self.assertEqual(captured.captured_queries[0]['sql'], - ' '.join('''SELECT "tests_question"."id", "tests_question"."question_text", "tests_question"."pub_date", "tests_question"."topic_id" - FROM "tests_question" - ORDER BY "tests_question"."id" - DESC - LIMIT 15'''.replace('\n', ' ').split())) + ' '.join('''SELECT "tests_question"."id", "tests_question"."question_text", "tests_question"."pub_date", "tests_question"."topic_id", "tests_topic"."id", "tests_topic"."name" + FROM "tests_question" + LEFT OUTER JOIN "tests_topic" ON ("tests_question"."topic_id" = "tests_topic"."id") + ORDER BY "tests_question"."id" + DESC + LIMIT 15'''.replace('\n', ' ').split())) self.assertEqual(data, { 'data': [ { @@ -145,7 +146,7 @@ def test_get_sort(self, mocked_datetime, mocked_decode): @mock.patch('jose.jwt.decode', return_value={'id': 1, 'rendering_id': 1}) @mock.patch('django_forest.utils.scope.ScopeManager._has_cache_expired', return_value=False) def test_get_sort_related_data(self, mocked_scope_has_expired, mocked_decode): - with self._django_assert_num_queries(7) as captured: + with self._django_assert_num_queries(4) as captured: response = self.client.get(self.reverse_url, { 'fields[tests_choice]': 'id,topic,question,choice_text', 'fields[topic]': 'name', @@ -157,9 +158,89 @@ def test_get_sort_related_data(self, mocked_scope_has_expired, mocked_decode): }) self.assertEqual(response.status_code, 200) self.assertEqual(captured.captured_queries[0]['sql'], - ' '.join('''SELECT "tests_choice"."id", "tests_choice"."question_id", "tests_choice"."choice_text" + ' '.join('''SELECT "tests_choice"."id", "tests_choice"."question_id", "tests_choice"."choice_text", "tests_question"."id", "tests_question"."question_text", "tests_question"."pub_date", "tests_question"."topic_id" FROM "tests_choice" LEFT OUTER JOIN "tests_question" ON ("tests_choice"."question_id" = "tests_question"."id") ORDER BY "tests_question"."question_text" ASC LIMIT 15'''.replace('\n', ' ').split())) + data = response.json() + self.assertEqual(data, { + "data": [ + { + "type": "tests_choice", + "relationships": { + "topic": { + "links": { + "related": "/forest/tests_choice/3/relationships/topic" + }, + "data": None, + }, + "question": { + "links": { + "related": "/forest/tests_choice/3/relationships/question" + }, + "data": {"type": "tests_question", "id": "2"}, + }, + }, + "id": 3, + "attributes": {"choice_text": "good"}, + "links": {"self": "/forest/tests_choice/3"}, + }, + { + "type": "tests_choice", + "relationships": { + "topic": { + "links": { + "related": "/forest/tests_choice/1/relationships/topic" + }, + "data": None, + }, + "question": { + "links": { + "related": "/forest/tests_choice/1/relationships/question" + }, + "data": {"type": "tests_question", "id": "1"}, + }, + }, + "id": 1, + "attributes": {"choice_text": "yes"}, + "links": {"self": "/forest/tests_choice/1"}, + }, + { + "type": "tests_choice", + "relationships": { + "topic": { + "links": { + "related": "/forest/tests_choice/2/relationships/topic" + }, + "data": None, + }, + "question": { + "links": { + "related": "/forest/tests_choice/2/relationships/question" + }, + "data": {"type": "tests_question", "id": "1"}, + }, + }, + "id": 2, + "attributes": {"choice_text": "no"}, + "links": {"self": "/forest/tests_choice/2"}, + }, + ], + "included": [ + { + "type": "tests_question", + "attributes": {"question_text": "do you like chocolate?"}, + "links": {"self": "/forest/tests_question/2"}, + "id": 2, + }, + { + "type": "tests_question", + "attributes": {"question_text": "what is your favorite color?"}, + "links": {"self": "/forest/tests_question/1"}, + "id": 1, + }, + ], + } + ) \ No newline at end of file