Skip to content

Commit e5dbc49

Browse files
committed
feat: support prewhere, finish test for django 4.2
1 parent e351a30 commit e5dbc49

File tree

8 files changed

+199
-19
lines changed

8 files changed

+199
-19
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Read [Documentation](https://github.com/jayvynl/django-clickhouse-backend/blob/m
2828
- Support creating test database and table, working with django TestCase and pytest-django.
2929
- Support most clickhouse data types.
3030
- Support [SETTINGS in SELECT Query](https://clickhouse.com/docs/en/sql-reference/statements/select/#settings-in-select-query).
31+
- Support [PREWHERE clause](https://clickhouse.com/docs/en/sql-reference/statements/select/prewhere).
3132

3233
**Notes:**
3334

clickhouse_backend/models/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def get_queryset(self):
2626
model=self.model, query=Query(self.model), using=self._db, hints=self._hints
2727
)
2828

29+
def prewhere(self, *args, **kwargs):
30+
return self.get_queryset().prewhere(*args, **kwargs)
31+
2932

3033
class ClickhouseModel(models.Model):
3134
objects = ClickhouseManager()

clickhouse_backend/models/query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from django.db.models import query
1+
from django.db.models import Q, query
22

3-
from .sql import Query
3+
from clickhouse_backend.models import sql
44

55

66
class QuerySet(query.QuerySet):
@@ -14,7 +14,7 @@ def explain(self, *, format=None, type=None, **settings):
1414

1515
def settings(self, **kwargs):
1616
clone = self._chain()
17-
if isinstance(clone.query, Query):
17+
if isinstance(clone.query, sql.Query):
1818
clone.query.setting_info.update(kwargs)
1919
return clone
2020

@@ -27,5 +27,5 @@ def prewhere(self, *args, **kwargs):
2727
if (args or kwargs) and self.query.is_sliced:
2828
raise TypeError("Cannot prewhere a query once a slice has been taken.")
2929
clone = self._chain()
30-
clone._query.add_prewhere(query.Q(*args, **kwargs))
30+
clone._query.add_prewhere(Q(*args, **kwargs))
3131
return clone

clickhouse_backend/models/sql/compiler.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from clickhouse_backend import compat
99
from clickhouse_backend.idworker import id_worker
1010
from clickhouse_backend.models import engines
11+
from clickhouse_backend.models.sql import Query
1112

1213
if compat.dj_ge42:
1314
from django.core.exceptions import FullResultSet
@@ -75,13 +76,24 @@ def pre_sql_setup(self, with_col_aliases=False):
7576
) = self.query.where.split_having_qualify(
7677
must_group_by=self.query.group_by is not None
7778
)
78-
(
79-
self.prewhere,
80-
prehaving,
81-
prequalify,
82-
) = self.query.prewhere.split_having_qualify(
83-
must_group_by=self.query.group_by is not None
84-
)
79+
if isinstance(self.query, Query):
80+
(
81+
self.prewhere,
82+
prehaving,
83+
prequalify,
84+
) = self.query.prewhere.split_having_qualify(
85+
must_group_by=self.query.group_by is not None
86+
)
87+
else:
88+
(
89+
self.prewhere,
90+
prehaving,
91+
prequalify,
92+
) = (
93+
None,
94+
None,
95+
None,
96+
)
8597
# Check before ClickHouse complain.
8698
# DB::Exception: Window function is found in PREWHERE in query. (ILLEGAL_AGGREGATION)
8799
if prequalify:

clickhouse_backend/models/sql/query.py

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.db import router
44
from django.db.models.sql import query, subqueries
55
from django.db.models.sql.constants import INNER
6+
from django.db.models.sql.datastructures import BaseTable, Join
67
from django.db.models.sql.where import AND
78

89
from clickhouse_backend import compat
@@ -26,6 +27,7 @@ def sql_with_params(self):
2627
def clone(self):
2728
obj = super().clone()
2829
obj.setting_info = self.setting_info.copy()
30+
obj.prewhere = self.prewhere.clone()
2931
return obj
3032

3133
def explain(self, using, format=None, type=None, **settings):
@@ -36,15 +38,8 @@ def explain(self, using, format=None, type=None, **settings):
3638

3739
def add_prewhere(self, q_object):
3840
"""
39-
A preprocessor for the internal _add_q(). Responsible for doing final
40-
join promotion.
41+
refer add_q
4142
"""
42-
# For join promotion this case is doing an AND for the added q_object
43-
# and existing conditions. So, any existing inner join forces the join
44-
# type to remain inner. Existing outer joins can however be demoted.
45-
# (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
46-
# rel_a doesn't produce any rows, then the whole condition must fail.
47-
# So, demotion is OK.
4843
existing_inner = {
4944
a for a in self.alias_map if self.alias_map[a].join_type == INNER
5045
}
@@ -59,6 +54,82 @@ def add_prewhere(self, q_object):
5954
def is_sliced(self):
6055
return self.low_mark != 0 or self.high_mark is not None
6156

57+
def resolve_expression(self, query, *args, **kwargs):
58+
clone = self.clone()
59+
# Subqueries need to use a different set of aliases than the outer query.
60+
clone.bump_prefix(query)
61+
clone.subquery = True
62+
clone.where.resolve_expression(query, *args, **kwargs)
63+
clone.prewhere.resolve_expression(query, *args, **kwargs)
64+
# Resolve combined queries.
65+
if clone.combinator:
66+
clone.combined_queries = tuple(
67+
[
68+
combined_query.resolve_expression(query, *args, **kwargs)
69+
for combined_query in clone.combined_queries
70+
]
71+
)
72+
for key, value in clone.annotations.items():
73+
resolved = value.resolve_expression(query, *args, **kwargs)
74+
if hasattr(resolved, "external_aliases"):
75+
resolved.external_aliases.update(clone.external_aliases)
76+
clone.annotations[key] = resolved
77+
# Outer query's aliases are considered external.
78+
for alias, table in query.alias_map.items():
79+
clone.external_aliases[alias] = (
80+
isinstance(table, Join)
81+
and table.join_field.related_model._meta.db_table != alias
82+
) or (
83+
isinstance(table, BaseTable) and table.table_name != table.table_alias
84+
)
85+
return clone
86+
87+
def change_aliases(self, change_map):
88+
"""
89+
Change the aliases in change_map (which maps old-alias -> new-alias),
90+
relabelling any references to them in select columns and the where
91+
clause.
92+
"""
93+
# If keys and values of change_map were to intersect, an alias might be
94+
# updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
95+
# on their order in change_map.
96+
assert set(change_map).isdisjoint(change_map.values())
97+
98+
# 1. Update references in "select" (normal columns plus aliases),
99+
# "group by" and "where".
100+
self.where.relabel_aliases(change_map)
101+
self.prewhere.relabel_aliases(change_map)
102+
if isinstance(self.group_by, tuple):
103+
self.group_by = tuple(
104+
[col.relabeled_clone(change_map) for col in self.group_by]
105+
)
106+
self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
107+
self.annotations = self.annotations and {
108+
key: col.relabeled_clone(change_map)
109+
for key, col in self.annotations.items()
110+
}
111+
112+
# 2. Rename the alias in the internal table/alias datastructures.
113+
for old_alias, new_alias in change_map.items():
114+
if old_alias not in self.alias_map:
115+
continue
116+
alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
117+
self.alias_map[new_alias] = alias_data
118+
self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
119+
del self.alias_refcount[old_alias]
120+
del self.alias_map[old_alias]
121+
122+
table_aliases = self.table_map[alias_data.table_name]
123+
for pos, alias in enumerate(table_aliases):
124+
if alias == old_alias:
125+
table_aliases[pos] = new_alias
126+
break
127+
self.external_aliases = {
128+
# Table is aliased or it's being changed and thus is aliased.
129+
change_map.get(alias, alias): (aliased or alias in change_map)
130+
for alias, aliased in self.external_aliases.items()
131+
}
132+
62133

63134
def clone_decorator(cls):
64135
old_clone = cls.clone

tests/clickhouse_queries/__init__.py

Whitespace-only changes.

tests/clickhouse_queries/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from django.db.models import CASCADE, ForeignKey
2+
3+
from clickhouse_backend import models
4+
5+
6+
class Author(models.ClickhouseModel):
7+
name = models.StringField(max_length=10)
8+
num = models.UInt32Field()
9+
10+
11+
class Book(models.ClickhouseModel):
12+
author = ForeignKey(Author, on_delete=CASCADE, related_name="books")
13+
name = models.StringField(max_length=10)
14+
15+
16+
class Article(models.ClickhouseModel):
17+
title = models.StringField(max_length=10)
18+
book = models.Int64Field()

tests/clickhouse_queries/tests.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from django.db import NotSupportedError
2+
from django.db.models import Count, Window
3+
from django.db.models.functions import Rank
4+
from django.test import TestCase
5+
6+
from . import models
7+
8+
9+
class QueriesTests(TestCase):
10+
@classmethod
11+
def setUpTestData(cls):
12+
cls.a1, cls.a2 = models.Author.objects.bulk_create(
13+
[models.Author(name="a1", num=1001), models.Author(name="a2", num=2002)]
14+
)
15+
cls.b1, cls.b2, cls.b3, cls.b4 = models.Book.objects.bulk_create(
16+
[
17+
models.Book(name="b1", author=cls.a1),
18+
models.Book(name="b2", author=cls.a1),
19+
models.Book(name="b3", author=cls.a2),
20+
models.Book(name="b4", author=cls.a2),
21+
]
22+
)
23+
models.Article.objects.bulk_create(
24+
[
25+
models.Article(title="t1", book=cls.b1.id),
26+
models.Article(title="t2", book=cls.b2.id),
27+
]
28+
)
29+
30+
def test_prewhere(self):
31+
qs = models.Author.objects.prewhere(name="a1")
32+
self.assertIn("PREWHERE", str(qs.query))
33+
self.assertEqual(qs[0].name, "a1")
34+
35+
def test_prewhere_fk(self):
36+
self.assertQuerySetEqual(
37+
models.Book.objects.filter(author__name=self.a1.name)
38+
.prewhere(author_id=self.a1.id)
39+
.order_by("name"),
40+
[self.b1, self.b2],
41+
)
42+
43+
# clickhouse backend will generate suitable query, but clickhouse will raise exception.
44+
# clickhouse 23.11
45+
# DB::Exception: Missing columns: 'clickhouse_queries_article.book' while processing query: 'SELECT name FROM clickhouse_queries_book AS U0 PREWHERE id = clickhouse_queries_article.book', required columns: 'name' 'id' 'clickhouse_queries_article.book', maybe you meant: 'name' or 'id': While processing (SELECT U0.name FROM clickhouse_queries_book AS U0 PREWHERE U0.id = clickhouse_queries_article.book) AS book_name.
46+
# clickhouse 24.6
47+
# DB::Exception: Resolve identifier 'clickhouse_queries_article.book' from parent scope only supported for constants and CTE. Actual test_default.clickhouse_queries_article.book node type COLUMN. In scope (SELECT U0.name FROM clickhouse_queries_book AS U0 PREWHERE U0.id = clickhouse_queries_article.book) AS book_name.
48+
# def test_prewhere_subquery(self):
49+
# a = models.Article.objects.annotate(
50+
# book_name=Subquery(
51+
# models.Book.objects.prewhere(id=OuterRef("book")).values("name")
52+
# )
53+
# ).get(title="t1")
54+
# self.assertEqual(a.book_name, self.b1.name)
55+
56+
def test_prewhere_agg(self):
57+
with self.assertRaisesMessage(
58+
NotSupportedError,
59+
"Aggregate function is disallowed in the prewhere clause.",
60+
):
61+
list(
62+
models.Author.objects.annotate(count=Count("books")).prewhere(
63+
count__gt=0
64+
)
65+
)
66+
67+
def test_prewhere_window(self):
68+
with self.assertRaisesMessage(
69+
NotSupportedError, "Window function is disallowed in the prewhere clause."
70+
):
71+
list(
72+
models.Book.objects.annotate(
73+
rank=Window(Rank(), partition_by="author", order_by="name")
74+
).prewhere(rank__gt=1)
75+
)

0 commit comments

Comments
 (0)