def distinct_sql(self, fields, params, order_by=None):
    """
    Return the SQL fragment (and its parameters) that implements DISTINCT.

    Without ``fields`` a plain ``DISTINCT`` is produced.

    With ``fields`` (``DISTINCT ON``, which Redshift does not support, see
    https://github.com/jazzband/django-redshift-backend/issues/14) the
    fragment is a ``ROW_NUMBER() OVER (PARTITION BY ...)`` select column
    aliased ``row_number``. The fragment ends with a comma because the
    compiler places it directly in front of the regular column list; the
    compiler is then expected to keep only the rows whose ``row_number``
    is 1.

    ``order_by`` is the compiler's ordering list, i.e. items shaped like
    ``(expr, (sql, params, is_ref))``. Its SQL is copied into the window
    so that the first row of every partition is the one the query
    ordering would put first.

    Returns a ``(sql_parts, params)`` tuple of two lists.
    """
    if not fields:
        return ["DISTINCT"], []

    window = "PARTITION BY " + ", ".join(fields)
    # Parameters must follow placeholder order: first the ones that belong
    # to the partition fields, then the ones of the window ordering.
    window_params = list(params or [])
    if order_by:
        ordering = []
        for _, (o_sql, o_params, _) in order_by:
            ordering.append(o_sql)
            # The ordering SQL is duplicated into the window clause, so
            # its placeholders need their own copy of the parameters.
            window_params.extend(o_params)
        window += " ORDER BY " + ", ".join(ordering)
    return [f"ROW_NUMBER() OVER ({window}) AS row_number,"], window_params
import warnings

# The Delete/Insert/Update compilers are imported but not used in this
# module; presumably they are re-exported because the backend sets
# ``compiler_module`` to this module -- TODO confirm.
from django.db.models.sql.compiler import (
    SQLAggregateCompiler,
    SQLCompiler as BaseSQLCompiler,
    SQLDeleteCompiler,
    SQLInsertCompiler,
    SQLUpdateCompiler,
)
from django.db.transaction import TransactionManagementError
from django.db.utils import NotSupportedError
from django.utils.deprecation import RemovedInDjango31Warning


# NOTE(review): this is a fresh sentinel object. If the base compiler's
# compile() compares ``select_format`` by identity against its own FORCE
# constant, this one will not match -- confirm against the Django version
# in use.
FORCE = object()


class SQLCompiler(BaseSQLCompiler):
    """
    SQL compiler that emulates ``DISTINCT ON (fields)`` for Redshift.

    When the query has distinct fields, the regular SELECT gets an extra
    ``ROW_NUMBER() OVER (PARTITION BY ...) AS row_number`` column (produced
    by ``connection.ops.distinct_sql``) and is wrapped in an outer query
    that keeps only the rows with ``row_number = 1``.
    """

    def as_sql(self, with_limits=True, with_col_aliases=False):
        """
        Create the SQL for this query. Return the SQL string and a tuple
        of parameters.

        If 'with_limits' is False, any limit/offset information is not
        included in the query. If 'with_col_aliases' is True, every select
        column without an explicit alias is aliased ``Col<n>``.

        Raises NotSupportedError for combinators or select_for_update
        options the backend features do not support,
        TransactionManagementError for select_for_update in autocommit
        mode, and NotImplementedError for GROUP BY combined with distinct
        fields.
        """
        refcounts_before = self.query.alias_refcount.copy()
        try:
            extra_select, order_by, group_by = self.pre_sql_setup()
            for_update_part = None
            # Is a LIMIT/OFFSET clause needed?
            with_limit_offset = with_limits and (
                self.query.high_mark is not None or self.query.low_mark
            )
            combinator = self.query.combinator
            features = self.connection.features
            if combinator:
                if not getattr(
                    features, 'supports_select_{}'.format(combinator)
                ):
                    raise NotSupportedError(
                        '{} is not supported on this database backend.'.format(
                            combinator
                        )
                    )
                result, params = self.get_combinator_sql(
                    combinator, self.query.combinator_all
                )
            else:
                distinct_fields, distinct_params = self.get_distinct()
                # This must come after 'select', 'ordering', and 'distinct'
                # (see docstring of get_from_clause() for details).
                from_, f_params = self.get_from_clause()
                where, w_params = (
                    self.compile(self.where)
                    if self.where is not None
                    else ("", [])
                )
                having, h_params = (
                    self.compile(self.having)
                    if self.having is not None
                    else ("", [])
                )
                result = ['SELECT']
                params = []

                if self.query.distinct:
                    # For distinct fields this yields the ROW_NUMBER()
                    # column (with trailing comma) instead of DISTINCT ON.
                    (
                        distinct_result,
                        distinct_params,
                    ) = self.connection.ops.distinct_sql(
                        distinct_fields,
                        distinct_params,
                        order_by,
                    )
                    result += distinct_result
                    params += distinct_params

                out_cols = []
                # Name under which each selected column is visible from
                # outside this SELECT; needed when the query gets wrapped
                # for the distinct-fields emulation below.
                outer_cols = []
                col_idx = 1
                for _, (s_sql, s_params), alias in self.select + extra_select:
                    if alias:
                        exposed = self.connection.ops.quote_name(alias)
                        s_sql = '%s AS %s' % (s_sql, exposed)
                    elif with_col_aliases:
                        # Spelled exactly like the inner alias (unquoted) so
                        # both references resolve to the same identifier.
                        exposed = 'Col%d' % col_idx
                        s_sql = '%s AS %s' % (s_sql, exposed)
                        col_idx += 1
                    else:
                        # Plain '"table"."column"' reference: the subquery
                        # exposes it under the bare column name.
                        # NOTE(review): an un-aliased expression that is
                        # not a simple column reference cannot be named
                        # this way -- confirm such selects never reach the
                        # distinct-fields path.
                        exposed = s_sql.rpartition('.')[2]
                    params.extend(s_params)
                    out_cols.append(s_sql)
                    outer_cols.append(exposed)

                result += [', '.join(out_cols), 'FROM', *from_]
                params.extend(f_params)

                if (
                    self.query.select_for_update
                    and features.has_select_for_update
                ):
                    if self.connection.get_autocommit():
                        raise TransactionManagementError(
                            'select_for_update cannot be used outside of a transaction.'
                        )

                    if (
                        with_limit_offset
                        and not features.supports_select_for_update_with_limit
                    ):
                        raise NotSupportedError(
                            'LIMIT/OFFSET is not supported with '
                            'select_for_update on this database backend.'
                        )
                    nowait = self.query.select_for_update_nowait
                    skip_locked = self.query.select_for_update_skip_locked
                    of = self.query.select_for_update_of
                    # If it's a NOWAIT/SKIP LOCKED/OF query but the backend
                    # doesn't support it, raise NotSupportedError to prevent a
                    # possible deadlock.
                    if nowait and not features.has_select_for_update_nowait:
                        raise NotSupportedError(
                            'NOWAIT is not supported on this database backend.'
                        )
                    elif (
                        skip_locked
                        and not features.has_select_for_update_skip_locked
                    ):
                        raise NotSupportedError(
                            'SKIP LOCKED is not supported on this database backend.'
                        )
                    elif of and not features.has_select_for_update_of:
                        raise NotSupportedError(
                            'FOR UPDATE OF is not supported on this database backend.'
                        )
                    for_update_part = self.connection.ops.for_update_sql(
                        nowait=nowait,
                        skip_locked=skip_locked,
                        of=self.get_select_for_update_of_arguments(),
                    )

                if for_update_part and features.for_update_after_from:
                    result.append(for_update_part)

                if where:
                    result.append('WHERE %s' % where)
                    params.extend(w_params)

                grouping = []
                for g_sql, g_params in group_by:
                    grouping.append(g_sql)
                    params.extend(g_params)
                if grouping:
                    if distinct_fields:
                        raise NotImplementedError(
                            'annotate() + distinct(fields) is not implemented.'
                        )
                    order_by = (
                        order_by or self.connection.ops.force_no_ordering()
                    )
                    result.append('GROUP BY %s' % ', '.join(grouping))
                    if self._meta_ordering:
                        # When the deprecation ends, replace with:
                        # order_by = None
                        warnings.warn(
                            "%s QuerySet won't use Meta.ordering in Django 3.1. "
                            "Add .order_by(%s) to retain the current query."
                            % (
                                self.query.model.__name__,
                                ', '.join(
                                    repr(f) for f in self._meta_ordering
                                ),
                            ),
                            RemovedInDjango31Warning,
                            stacklevel=4,
                        )
                if having:
                    result.append('HAVING %s' % having)
                    params.extend(h_params)

            if self.query.explain_query:
                # NOTE(review): for distinct-fields queries this prefix ends
                # up inside the wrapping subquery built below -- confirm
                # whether EXPLAIN needs to be moved to the outer query.
                result.insert(
                    0,
                    self.connection.ops.explain_query_prefix(
                        self.query.explain_format, **self.query.explain_options
                    ),
                )

            if order_by:
                ordering = []
                for _, (o_sql, o_params, _) in order_by:
                    ordering.append(o_sql)
                    params.extend(o_params)
                result.append('ORDER BY %s' % ', '.join(ordering))

            if with_limit_offset:
                # NOTE(review): for distinct-fields queries LIMIT/OFFSET is
                # applied inside the subquery, i.e. before rows are reduced
                # to row_number = 1 -- confirm this is the intended
                # semantics.
                result.append(
                    self.connection.ops.limit_offset_sql(
                        self.query.low_mark, self.query.high_mark
                    )
                )

            if for_update_part and not features.for_update_after_from:
                result.append(for_update_part)

            if self.query.distinct_fields:
                # DISTINCT ON emulation: keep the first row of every
                # partition and select only the real columns, which drops
                # the helper row_number column from the output.
                pre_result = " ".join(result)
                tb_out_cols = ", ".join(
                    f'"tb".{name}' for name in outer_cols
                )
                sql = (
                    f'SELECT {tb_out_cols} FROM ({pre_result}) AS "tb" '
                    f'WHERE "tb"."row_number" = 1'
                )
                return sql, tuple(params)

            if self.query.subquery and extra_select:
                # If the query is used as a subquery, the extra selects would
                # result in more columns than the left-hand side expression is
                # expecting. This can happen when a subquery uses a combination
                # of order_by() and distinct(), forcing the ordering expressions
                # to be selected as well. Wrap the query in another subquery
                # to exclude extraneous selects.
                sub_selects = []
                sub_params = []
                for index, (select, _, alias) in enumerate(
                    self.select, start=1
                ):
                    if not alias and with_col_aliases:
                        alias = 'col%d' % index
                    if alias:
                        sub_selects.append(
                            "%s.%s"
                            % (
                                self.connection.ops.quote_name('subquery'),
                                self.connection.ops.quote_name(alias),
                            )
                        )
                    else:
                        select_clone = select.relabeled_clone(
                            {select.alias: 'subquery'}
                        )
                        subselect, subparams = select_clone.as_sql(
                            self, self.connection
                        )
                        sub_selects.append(subselect)
                        sub_params.extend(subparams)
                return 'SELECT %s FROM (%s) subquery' % (
                    ', '.join(sub_selects),
                    ' '.join(result),
                ), tuple(sub_params + params)

            return ' '.join(result), tuple(params)
        finally:
            # Finally do cleanup - get rid of the joins we created above.
            self.query.reset_refcounts(refcounts_before)


class SQLAggregateCompiler(SQLCompiler):
    """
    Aggregate compiler built on this module's SQLCompiler.

    Note that this definition replaces the ``SQLAggregateCompiler`` name
    imported from Django at the top of the module.
    """

    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and a tuple
        of parameters.

        The annotations are selected from ``self.query.subquery``, and the
        annotation parameters precede ``self.query.sub_params``.
        """
        sql, params = [], []
        for annotation in self.query.annotation_select.values():
            ann_sql, ann_params = self.compile(annotation, select_format=FORCE)
            sql.append(ann_sql)
            params.extend(ann_params)
        self.col_count = len(self.query.annotation_select)
        sql = ', '.join(sql)
        params = tuple(params)

        sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
        params = params + self.query.sub_params
        return sql, params
"tb"."uuid" + FROM ( + SELECT + ROW_NUMBER() OVER ( + PARTITION BY + "testapp_testmodel"."uuid" + ORDER BY + "testapp_testmodel"."uuid" ASC, + "testapp_testmodel"."ctime" DESC + ) AS row_number, + "testapp_testmodel"."id", + "testapp_testmodel"."ctime", + "testapp_testmodel"."text", + "testapp_testmodel"."uuid" + FROM "testapp_testmodel" + WHERE ("testapp_testmodel"."ctime" <= %s AND "testapp_testmodel"."text" = %s) + ORDER BY + "testapp_testmodel"."uuid" ASC, + "testapp_testmodel"."ctime" DESC + ) AS "tb" + WHERE "tb"."row_number" = 1 +''') class ModelTest(unittest.TestCase): @@ -130,10 +159,18 @@ def test_distinct(self): def test_distinct_with_fields(self): from testapp.models import TestModel - query = TestModel.objects.distinct('text').query + query = ( + TestModel.objects.filter( + text='test', + ctime__lte=now() + ) + .order_by('uuid', '-ctime') + .distinct('uuid') + .query + ) compiler = query.get_compiler(using='default') - with self.assertRaises(NotSupportedError): - compiler.as_sql() + sql = norm_sql(compiler.as_sql()[0]) + self.assertEqual(sql, expected_dml_distinct_fields) class MigrationTest(unittest.TestCase):