diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 3f8b9cd..644ec8c 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -1055,3 +1055,45 @@ def test_tree_fields_optimization(self): child2_2 = next(obj for obj in results if obj.name == "2-2") assert child2_2.tree_names == ["root", "2", "2-2"] + + def test_combinator_operations(self): + """Test that union, intersection, and difference work with tree queries""" + tree = self.create_tree() + + # Test union operation + qs1 = Model.objects.with_tree_fields().filter(name__in=["root", "1"]) + qs2 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1"]) + union_result = list(qs1.union(qs2)) + + # Should have 4 unique objects + assert len(union_result) == 4 + names = {obj.name for obj in union_result} + assert names == {"root", "1", "2", "2-1"} + + # Tree fields should not be available in combinator queries + assert not hasattr(union_result[0], 'tree_depth') + + # Test intersection operation + qs3 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"]) + qs4 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1", "2-2"]) + intersect_result = list(qs3.intersection(qs4)) + + # Should have 1 object in common + assert len(intersect_result) == 1 + assert intersect_result[0].name == "2" + + # Test difference operation + qs5 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"]) + qs6 = Model.objects.with_tree_fields().filter(name__in=["2"]) + diff_result = list(qs5.difference(qs6)) + + # Should have 2 objects (1 and 1-1) + assert len(diff_result) == 2 + names = {obj.name for obj in diff_result} + assert names == {"1", "1-1"} + + # Verify that regular tree queries still work normally + regular_tree_result = list(Model.objects.with_tree_fields().filter(name="root")) + assert len(regular_tree_result) == 1 + assert hasattr(regular_tree_result[0], 'tree_depth') + assert regular_tree_result[0].tree_depth == 0 diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 06dbc1d..e46f9db 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -43,7 +43,7 @@ def _setup_query(self): if not hasattr(self, "tree_fields"): self.tree_fields = {} - def get_compiler(self, using=None, connection=None, **kwargs): + def get_compiler(self, using=None, connection=None, elide_empty=True, **kwargs): # Copied from django/db/models/sql/query.py if using is None and connection is None: raise ValueError("Need either using or connection") @@ -52,8 +52,8 @@ def get_compiler(self, using=None, connection=None, **kwargs): # Difference: Not connection.ops.compiler, but our own compiler which # adds the CTE. - # **kwargs passes on elide_empty from Django 4.0 onwards - return TreeCompiler(self, connection, using, **kwargs) + # Pass elide_empty and other kwargs to TreeCompiler + return TreeCompiler(self, connection, using, elide_empty, **kwargs) def get_sibling_order(self): return self.sibling_order @@ -379,6 +379,13 @@ def get_rank_table(self): return rank_table_sql, rank_table_params + def _is_part_of_combinator(self): + """ + This method is no longer used as we handle combinators differently. + Kept for backward compatibility. + """ + return False + def as_sql(self, *args, **kwargs): # Try detecting if we're used in a EXISTS(1 as "a") subquery like # Django's sql.Query.exists() generates. If we detect such a query @@ -392,6 +399,12 @@ def as_sql(self, *args, **kwargs): ): return super().as_sql(*args, **kwargs) + # Check if this query is part of a combinator operation (union/intersection/difference). + # Tree fields (CTEs) are not compatible with combinator operations, so we fall back + # to regular query compilation without tree fields. + if hasattr(self.query, 'combinator') and self.query.combinator: + return super().as_sql(*args, **kwargs) + # The general idea is that if we have a summary query (e.g. .count()) # then we do not want to ask Django to add the tree fields to the query # using .query.add_extra. The way to determine whether we have a @@ -559,6 +572,107 @@ def get_converters(self, expressions): converters[i] = ([converter], expression) return converters + def get_combinator_sql(self, combinator, all): + """ + Override combinator SQL generation to handle tree fields properly. + + When using union/intersection/difference with tree queries, tree fields + (computed via CTEs) are incompatible with combinator operations. + + The solution is to ensure all queries in the combinator operation are + compiled as regular queries without tree functionality. + """ + # Convert all TreeQuery instances to regular Query instances + # and use regular SQLCompiler for all parts + + features = self.connection.features + regular_compilers = [] + + for query in self.query.combined_queries: + # Clone the query to avoid modifying the original + cloned_query = query.clone() + + # Convert TreeQuery to regular Query + if hasattr(cloned_query, '__class__') and cloned_query.__class__.__name__ == 'TreeQuery': + cloned_query.__class__ = Query + + # Clear any tree-related orderings that might cause issues + cloned_query.clear_ordering(force=True) + + # Get a regular SQLCompiler for this query + compiler = cloned_query.get_compiler(self.using, self.connection, self.elide_empty) + regular_compilers.append(compiler) + + # Handle slicing and ordering restrictions (from Django's implementation) + if not features.supports_slicing_ordering_in_compound: + for compiler in regular_compilers: + if compiler.query.is_sliced: + from django.db import DatabaseError + raise DatabaseError( + "LIMIT/OFFSET not allowed in subqueries of compound statements." + ) + if compiler.get_order_by(): + from django.db import DatabaseError + raise DatabaseError( + "ORDER BY not allowed in subqueries of compound statements." + ) + elif self.query.is_sliced and combinator == "union": + for compiler in regular_compilers: + compiler.elide_empty = False + + # Generate SQL for each part + parts = [] + for compiler in regular_compilers: + try: + part_sql, part_args = compiler.as_sql(with_col_aliases=True) + + # Handle nested combinators and subqueries (from Django's implementation) + if compiler.query.combinator: + if not features.supports_parentheses_in_compound: + part_sql = "SELECT * FROM ({})".format(part_sql) + elif ( + self.query.subquery + or not features.supports_slicing_ordering_in_compound + ): + part_sql = "({})".format(part_sql) + elif ( + self.query.subquery + and features.supports_slicing_ordering_in_compound + ): + part_sql = "({})".format(part_sql) + + parts.append((part_sql, part_args)) + except Exception: + # Handle empty results (similar to Django's EmptyResultSet handling) + if combinator == "union" or (combinator == "difference" and parts): + continue + raise + + if not parts: + from django.db.models.sql.compiler import EmptyResultSet + raise EmptyResultSet + + # Combine the parts + combinator_sql = self.connection.ops.set_operators[combinator] + if all and combinator == "union": + combinator_sql += " ALL" + + braces = "{}" + if not self.query.subquery and features.supports_slicing_ordering_in_compound: + braces = "({})" + + sql_parts, args_parts = zip(*parts) + result = [braces.format(part) for part in sql_parts] + + # Join the parts with the combinator + final_sql = (" %s " % combinator_sql).join(result) + final_params = [] + for part_args in args_parts: + final_params.extend(part_args) + + # Return in Django's expected format: [sql], params + return [final_sql], final_params + def converter(value, expression, connection, context=None): # context can be removed as soon as we only support Django>=2.0