Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,45 @@ def test_tree_fields_optimization(self):

child2_2 = next(obj for obj in results if obj.name == "2-2")
assert child2_2.tree_names == ["root", "2", "2-2"]

def test_combinator_operations(self):
"""Test that union, intersection, and difference work with tree queries"""
tree = self.create_tree()

# Test union operation
qs1 = Model.objects.with_tree_fields().filter(name__in=["root", "1"])
qs2 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1"])
union_result = list(qs1.union(qs2))

# Should have 4 unique objects
assert len(union_result) == 4
names = {obj.name for obj in union_result}
assert names == {"root", "1", "2", "2-1"}

# Tree fields should not be available in combinator queries
assert not hasattr(union_result[0], 'tree_depth')

# Test intersection operation
qs3 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"])
qs4 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1", "2-2"])
intersect_result = list(qs3.intersection(qs4))

# Should have 1 object in common
assert len(intersect_result) == 1
assert intersect_result[0].name == "2"

# Test difference operation
qs5 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"])
qs6 = Model.objects.with_tree_fields().filter(name__in=["2"])
diff_result = list(qs5.difference(qs6))

# Should have 2 objects (1 and 1-1)
assert len(diff_result) == 2
names = {obj.name for obj in diff_result}
assert names == {"1", "1-1"}

# Verify that regular tree queries still work normally
regular_tree_result = list(Model.objects.with_tree_fields().filter(name="root"))
assert len(regular_tree_result) == 1
assert hasattr(regular_tree_result[0], 'tree_depth')
assert regular_tree_result[0].tree_depth == 0
120 changes: 117 additions & 3 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _setup_query(self):
if not hasattr(self, "tree_fields"):
self.tree_fields = {}

def get_compiler(self, using=None, connection=None, **kwargs):
def get_compiler(self, using=None, connection=None, elide_empty=True, **kwargs):
# Copied from django/db/models/sql/query.py
if using is None and connection is None:
raise ValueError("Need either using or connection")
Expand All @@ -52,8 +52,8 @@ def get_compiler(self, using=None, connection=None, **kwargs):
# Difference: Not connection.ops.compiler, but our own compiler which
# adds the CTE.

# **kwargs passes on elide_empty from Django 4.0 onwards
return TreeCompiler(self, connection, using, **kwargs)
# Pass elide_empty and other kwargs to TreeCompiler
return TreeCompiler(self, connection, using, elide_empty, **kwargs)

def get_sibling_order(self):
return self.sibling_order
Expand Down Expand Up @@ -379,6 +379,13 @@ def get_rank_table(self):

return rank_table_sql, rank_table_params

def _is_part_of_combinator(self):
"""
This method is no longer used as we handle combinators differently.
Kept for backward compatibility.
"""
return False

def as_sql(self, *args, **kwargs):
# Try detecting if we're used in a EXISTS(1 as "a") subquery like
# Django's sql.Query.exists() generates. If we detect such a query
Expand All @@ -392,6 +399,12 @@ def as_sql(self, *args, **kwargs):
):
return super().as_sql(*args, **kwargs)

# Check if this query is part of a combinator operation (union/intersection/difference).
# Tree fields (CTEs) are not compatible with combinator operations, so we fall back
# to regular query compilation without tree fields.
if hasattr(self.query, 'combinator') and self.query.combinator:
return super().as_sql(*args, **kwargs)

# The general idea is that if we have a summary query (e.g. .count())
# then we do not want to ask Django to add the tree fields to the query
# using .query.add_extra. The way to determine whether we have a
Expand Down Expand Up @@ -559,6 +572,107 @@ def get_converters(self, expressions):
converters[i] = ([converter], expression)
return converters

def get_combinator_sql(self, combinator, all):
"""
Override combinator SQL generation to handle tree fields properly.

When using union/intersection/difference with tree queries, tree fields
(computed via CTEs) are incompatible with combinator operations.

The solution is to ensure all queries in the combinator operation are
compiled as regular queries without tree functionality.
"""
# Convert all TreeQuery instances to regular Query instances
# and use regular SQLCompiler for all parts

features = self.connection.features
regular_compilers = []

for query in self.query.combined_queries:
# Clone the query to avoid modifying the original
cloned_query = query.clone()

# Convert TreeQuery to regular Query
if hasattr(cloned_query, '__class__') and cloned_query.__class__.__name__ == 'TreeQuery':
cloned_query.__class__ = Query

# Clear any tree-related orderings that might cause issues
cloned_query.clear_ordering(force=True)

# Get a regular SQLCompiler for this query
compiler = cloned_query.get_compiler(self.using, self.connection, self.elide_empty)
regular_compilers.append(compiler)

# Handle slicing and ordering restrictions (from Django's implementation)
if not features.supports_slicing_ordering_in_compound:
for compiler in regular_compilers:
if compiler.query.is_sliced:
from django.db import DatabaseError
raise DatabaseError(
"LIMIT/OFFSET not allowed in subqueries of compound statements."
)
if compiler.get_order_by():
from django.db import DatabaseError
raise DatabaseError(
"ORDER BY not allowed in subqueries of compound statements."
)
elif self.query.is_sliced and combinator == "union":
for compiler in regular_compilers:
compiler.elide_empty = False

# Generate SQL for each part
parts = []
for compiler in regular_compilers:
try:
part_sql, part_args = compiler.as_sql(with_col_aliases=True)

# Handle nested combinators and subqueries (from Django's implementation)
if compiler.query.combinator:
if not features.supports_parentheses_in_compound:
part_sql = "SELECT * FROM ({})".format(part_sql)
elif (
self.query.subquery
or not features.supports_slicing_ordering_in_compound
):
part_sql = "({})".format(part_sql)
elif (
self.query.subquery
and features.supports_slicing_ordering_in_compound
):
part_sql = "({})".format(part_sql)

parts.append((part_sql, part_args))
except Exception:
# Handle empty results (similar to Django's EmptyResultSet handling)
if combinator == "union" or (combinator == "difference" and parts):
continue
raise

if not parts:
from django.db.models.sql.compiler import EmptyResultSet
raise EmptyResultSet

# Combine the parts
combinator_sql = self.connection.ops.set_operators[combinator]
if all and combinator == "union":
combinator_sql += " ALL"

braces = "{}"
if not self.query.subquery and features.supports_slicing_ordering_in_compound:
braces = "({})"

sql_parts, args_parts = zip(*parts)
result = [braces.format(part) for part in sql_parts]

# Join the parts with the combinator
final_sql = (" %s " % combinator_sql).join(result)
final_params = []
for part_args in args_parts:
final_params.extend(part_args)

# Return in Django's expected format: [sql], params
return [final_sql], final_params


def converter(value, expression, connection, context=None):
# context can be removed as soon as we only support Django>=2.0
Expand Down