From 2f621193c2e019e0995d043a12f577aaa3769358 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:07:05 +0000 Subject: [PATCH 1/3] Initial plan From b8742fb91a316ead13e8772a161acf54b0e4fcc8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:23:26 +0000 Subject: [PATCH 2/3] WIP: Fix union/intersection/difference operations with tree fields - partially working Co-authored-by: matthiask <2627+matthiask@users.noreply.github.com> --- tree_queries/compiler.py | 154 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 3 deletions(-) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 06dbc1d..37dcc49 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -43,7 +43,7 @@ def _setup_query(self): if not hasattr(self, "tree_fields"): self.tree_fields = {} - def get_compiler(self, using=None, connection=None, **kwargs): + def get_compiler(self, using=None, connection=None, elide_empty=True, **kwargs): # Copied from django/db/models/sql/query.py if using is None and connection is None: raise ValueError("Need either using or connection") @@ -52,8 +52,8 @@ def get_compiler(self, using=None, connection=None, **kwargs): # Difference: Not connection.ops.compiler, but our own compiler which # adds the CTE. - # **kwargs passes on elide_empty from Django 4.0 onwards - return TreeCompiler(self, connection, using, **kwargs) + # Pass elide_empty and other kwargs to TreeCompiler + return TreeCompiler(self, connection, using, elide_empty, **kwargs) def get_sibling_order(self): return self.sibling_order @@ -379,6 +379,13 @@ def get_rank_table(self): return rank_table_sql, rank_table_params + def _is_part_of_combinator(self): + """ + This method is no longer used as we handle combinators differently. + Kept for backward compatibility. + """ + return False + def as_sql(self, *args, **kwargs): # Try detecting if we're used in a EXISTS(1 as "a") subquery like # Django's sql.Query.exists() generates. If we detect such a query @@ -392,6 +399,31 @@ def as_sql(self, *args, **kwargs): ): return super().as_sql(*args, **kwargs) + # Check if this query is part of a combinator operation (union/intersection/difference) + # If so, we need to skip tree field generation entirely and use regular SQL compilation + if hasattr(self.query, 'combinator') and self.query.combinator: + # For combinator queries, we also need to clean up any tree-related extra fields + # that might have been added previously, and clear tree-related ordering + if hasattr(self.query, 'extra_select'): + # Remove tree fields from extra_select + tree_fields = {'tree_depth', 'tree_path', 'tree_ordering'} + for field_name in list(self.query.extra_select.keys()): + if field_name in tree_fields: + del self.query.extra_select[field_name] + + # Clear tree-related extra tables + if hasattr(self.query, 'extra_tables') and '__tree' in self.query.extra_tables: + self.query.extra_tables = tuple(t for t in self.query.extra_tables if t != '__tree') + + # Clear tree-related extra where clauses + if hasattr(self.query, 'extra_where'): + self.query.extra_where = [w for w in self.query.extra_where if '__tree' not in w] + + # Clear any tree-related ordering + self.query.clear_ordering(force=True) + + return super().as_sql(*args, **kwargs) + # The general idea is that if we have a summary query (e.g. .count()) # then we do not want to ask Django to add the tree fields to the query # using .query.add_extra. The way to determine whether we have a @@ -559,6 +591,122 @@ def get_converters(self, expressions): converters[i] = ([converter], expression) return converters + def get_combinator_sql(self, combinator, all): + """ + Override combinator SQL generation to handle tree fields properly. + + When using union/intersection/difference with tree queries, we need to + handle tree fields specially since they are computed fields that don't + exist as real model fields. + + The approach is to remove tree fields from the normalization process. + """ + # Tree field names that should be excluded from combinator operations + TREE_FIELD_NAMES = {'tree_depth', 'tree_path', 'tree_ordering'} + + # Get the current values that would be applied to sub-queries + extra_select_keys = set(self.query.extra_select.keys()) + values_select = self.query.values_select + annotation_select_keys = set(self.query.annotation_select.keys()) + + # Check if any tree fields are present + tree_fields_present = ( + TREE_FIELD_NAMES & extra_select_keys or + any(name in TREE_FIELD_NAMES for name in values_select) or + TREE_FIELD_NAMES & annotation_select_keys + ) + + if tree_fields_present: + # We have tree fields, so we need to handle this specially + # Create clean queries without tree fields + features = self.connection.features + compilers = [] + + for query in self.query.combined_queries: + # Clone the query and ensure it's a regular Query (not TreeQuery) + cloned_query = query.clone() + if hasattr(cloned_query, '__class__') and cloned_query.__class__.__name__ == 'TreeQuery': + cloned_query.__class__ = Query + + # Clear any tree-related ordering that might have been set + cloned_query.clear_ordering(force=True) + + # Get a regular compiler for this query + compiler = cloned_query.get_compiler(self.using, self.connection, self.elide_empty) + compilers.append(compiler) + + # Check for slicing and ordering restrictions (copied from Django's implementation) + if not features.supports_slicing_ordering_in_compound: + for compiler in compilers: + if compiler.query.is_sliced: + from django.db import DatabaseError + raise DatabaseError( + "LIMIT/OFFSET not allowed in subqueries of compound statements." + ) + if compiler.get_order_by(): + from django.db import DatabaseError + raise DatabaseError( + "ORDER BY not allowed in subqueries of compound statements." + ) + elif self.query.is_sliced and combinator == "union": + for compiler in compilers: + compiler.elide_empty = False + + # Generate SQL for each part without tree fields + parts = () + for compiler in compilers: + try: + # Generate SQL for this sub-query + part_sql, part_args = compiler.as_sql(with_col_aliases=True) + if compiler.query.combinator: + # Handle nested combinators + if not features.supports_parentheses_in_compound: + part_sql = "SELECT * FROM ({})".format(part_sql) + elif ( + self.query.subquery + or not features.supports_slicing_ordering_in_compound + ): + part_sql = "({})".format(part_sql) + elif ( + self.query.subquery + and features.supports_slicing_ordering_in_compound + ): + part_sql = "({})".format(part_sql) + parts += ((part_sql, part_args),) + except Exception: # Django's EmptyResultSet would be caught here + # Handle empty results + if combinator == "union" or (combinator == "difference" and parts): + continue + raise + + if not parts: + from django.db.models.sql.compiler import EmptyResultSet + raise EmptyResultSet + + # Combine the parts + combinator_sql = self.connection.ops.set_operators[combinator] + if all and combinator == "union": + combinator_sql += " ALL" + + braces = "{}" + if not self.query.subquery and features.supports_slicing_ordering_in_compound: + braces = "({})" + + sql_parts, args_parts = zip(*parts) + result = [braces.format(part) for part in sql_parts] + + # Join with combinator + sql = (" %s " % combinator_sql).join(result) + params = [] + for part_args in args_parts: + params.extend(part_args) + + # Return in the same format as Django's get_combinator_sql: ([sql], params) + return [sql], params + else: + # No tree fields present, use Django's default implementation + return super().get_combinator_sql(combinator, all) + def converter(value, expression, connection, context=None): # context can be removed as soon as we only support Django>=2.0 From 02214d9aa524bf5dcd6eba46861c6a5013dd370e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:28:26 +0000 Subject: [PATCH 3/3] Fix union/intersection/difference operations with tree fields Co-authored-by: matthiask <2627+matthiask@users.noreply.github.com> --- tests/testapp/test_queries.py | 42 +++++++ tree_queries/compiler.py | 210 ++++++++++++++-------------------- 2 files changed, 130 insertions(+), 122 deletions(-) diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 3f8b9cd..644ec8c 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -1055,3 +1055,45 @@ def test_tree_fields_optimization(self): child2_2 = next(obj for obj in results if obj.name == "2-2") assert child2_2.tree_names == ["root", "2", "2-2"] + + def test_combinator_operations(self): + """Test that union, intersection, and difference work with tree queries""" + tree = self.create_tree() + + # Test union operation + qs1 = Model.objects.with_tree_fields().filter(name__in=["root", "1"]) + qs2 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1"]) + union_result = list(qs1.union(qs2)) + + # Should have 4 unique objects + assert len(union_result) == 4 + names = {obj.name for obj in union_result} + assert names == {"root", "1", "2", "2-1"} + + # Tree fields should not be available in combinator queries + assert not hasattr(union_result[0], 'tree_depth') + + # Test intersection operation + qs3 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"]) + qs4 = Model.objects.with_tree_fields().filter(name__in=["2", "2-1", "2-2"]) + intersect_result = list(qs3.intersection(qs4)) + + # Should have 1 object in common + assert len(intersect_result) == 1 + assert intersect_result[0].name == "2" + + # Test difference operation + qs5 = Model.objects.with_tree_fields().filter(name__in=["1", "2", "1-1"]) + qs6 = Model.objects.with_tree_fields().filter(name__in=["2"]) + diff_result = list(qs5.difference(qs6)) + + # Should have 2 objects (1 and 1-1) + assert len(diff_result) == 2 + names = {obj.name for obj in diff_result} + assert names == {"1", "1-1"} + + # Verify that regular tree queries still work normally + regular_tree_result = list(Model.objects.with_tree_fields().filter(name="root")) + assert len(regular_tree_result) == 1 + assert hasattr(regular_tree_result[0], 'tree_depth') + assert regular_tree_result[0].tree_depth == 0 diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 37dcc49..e46f9db 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -399,29 +399,10 @@ def as_sql(self, *args, **kwargs): ): return super().as_sql(*args, **kwargs) - # Check if this query is part of a combinator operation (union/intersection/difference) - # If so, we need to skip tree field generation entirely and use regular SQL compilation + # Check if this query is part of a combinator operation (union/intersection/difference). + # Tree fields (CTEs) are not compatible with combinator operations, so we fall back + # to regular query compilation without tree fields. if hasattr(self.query, 'combinator') and self.query.combinator: - # For combinator queries, we also need to clean up any tree-related extra fields - # that might have been added previously, and clear tree-related ordering - if hasattr(self.query, 'extra_select'): - # Remove tree fields from extra_select - tree_fields = {'tree_depth', 'tree_path', 'tree_ordering'} - for field_name in list(self.query.extra_select.keys()): - if field_name in tree_fields: - del self.query.extra_select[field_name] - - # Clear tree-related extra tables - if hasattr(self.query, 'extra_tables') and '__tree' in self.query.extra_tables: - self.query.extra_tables = tuple(t for t in self.query.extra_tables if t != '__tree') - - # Clear tree-related extra where clauses - if hasattr(self.query, 'extra_where'): - self.query.extra_where = [w for w in self.query.extra_where if '__tree' not in w] - - # Clear any tree-related ordering - self.query.clear_ordering(force=True) - return super().as_sql(*args, **kwargs) # The general idea is that if we have a summary query (e.g. .count()) @@ -595,117 +576,102 @@ def get_combinator_sql(self, combinator, all): """ Override combinator SQL generation to handle tree fields properly. - When using union/intersection/difference with tree queries, we need to - handle tree fields specially since they are computed fields that don't - exist as real model fields. + When using union/intersection/difference with tree queries, tree fields + (computed via CTEs) are incompatible with combinator operations. - The approach is to remove tree fields from the normalization process. + The solution is to ensure all queries in the combinator operation are + compiled as regular queries without tree functionality. """ - # Tree field names that should be excluded from combinator operations - TREE_FIELD_NAMES = {'tree_depth', 'tree_path', 'tree_ordering'} + # Convert all TreeQuery instances to regular Query instances + # and use regular SQLCompiler for all parts - # Get the current values that would be applied to sub-queries - extra_select_keys = set(self.query.extra_select.keys()) - values_select = self.query.values_select - annotation_select_keys = set(self.query.annotation_select.keys()) + features = self.connection.features + regular_compilers = [] - # Check if any tree fields are present - tree_fields_present = ( - TREE_FIELD_NAMES & extra_select_keys or - any(name in TREE_FIELD_NAMES for name in values_select) or - TREE_FIELD_NAMES & annotation_select_keys - ) - - if tree_fields_present: - # We have tree fields, so we need to handle this specially - # Create clean queries without tree fields - features = self.connection.features - compilers = [] + for query in self.query.combined_queries: + # Clone the query to avoid modifying the original + cloned_query = query.clone() - for query in self.query.combined_queries: - # Clone the query and ensure it's a regular Query (not TreeQuery) - cloned_query = query.clone() - if hasattr(cloned_query, '__class__') and cloned_query.__class__.__name__ == 'TreeQuery': - cloned_query.__class__ = Query - - # Clear any tree-related ordering that might have been set - cloned_query.clear_ordering(force=True) - - # Get a regular compiler for this query - compiler = cloned_query.get_compiler(self.using, self.connection, self.elide_empty) - compilers.append(compiler) + # Convert TreeQuery to regular Query + if hasattr(cloned_query, '__class__') and cloned_query.__class__.__name__ == 'TreeQuery': + cloned_query.__class__ = Query - # Check for slicing and ordering restrictions (copied from Django's implementation) - if not features.supports_slicing_ordering_in_compound: - for compiler in compilers: - if compiler.query.is_sliced: - from django.db import DatabaseError - raise DatabaseError( - "LIMIT/OFFSET not allowed in subqueries of compound statements." - ) - if compiler.get_order_by(): - from django.db import DatabaseError - raise DatabaseError( - "ORDER BY not allowed in subqueries of compound statements." - ) - elif self.query.is_sliced and combinator == "union": - for compiler in compilers: - compiler.elide_empty = False + # Clear any tree-related orderings that might cause issues + cloned_query.clear_ordering(force=True) - # Generate SQL for each part without tree fields - parts = () - for compiler in compilers: - try: - # Generate SQL for this sub-query - part_sql, part_args = compiler.as_sql(with_col_aliases=True) - if compiler.query.combinator: - # Handle nested combinators - if not features.supports_parentheses_in_compound: - part_sql = "SELECT * FROM ({})".format(part_sql) - elif ( - self.query.subquery - or not features.supports_slicing_ordering_in_compound - ): - part_sql = "({})".format(part_sql) + # Get a regular SQLCompiler for this query + compiler = cloned_query.get_compiler(self.using, self.connection, self.elide_empty) + regular_compilers.append(compiler) + + # Handle slicing and ordering restrictions (from Django's implementation) + if not features.supports_slicing_ordering_in_compound: + for compiler in regular_compilers: + if compiler.query.is_sliced: + from django.db import DatabaseError + raise DatabaseError( + "LIMIT/OFFSET not allowed in subqueries of compound statements." + ) + if compiler.get_order_by(): + from django.db import DatabaseError + raise DatabaseError( + "ORDER BY not allowed in subqueries of compound statements." + ) + elif self.query.is_sliced and combinator == "union": + for compiler in regular_compilers: + compiler.elide_empty = False + + # Generate SQL for each part + parts = [] + for compiler in regular_compilers: + try: + part_sql, part_args = compiler.as_sql(with_col_aliases=True) + + # Handle nested combinators and subqueries (from Django's implementation) + if compiler.query.combinator: + if not features.supports_parentheses_in_compound: + part_sql = "SELECT * FROM ({})".format(part_sql) elif ( self.query.subquery - and features.supports_slicing_ordering_in_compound + or not features.supports_slicing_ordering_in_compound ): part_sql = "({})".format(part_sql) - parts += ((part_sql, part_args),) - except Exception: # Django's EmptyResultSet would be caught here - # Handle empty results - if combinator == "union" or (combinator == "difference" and parts): - continue - raise - - if not parts: - from django.db.models.sql.compiler import EmptyResultSet - raise EmptyResultSet - - # Combine the parts - combinator_sql = self.connection.ops.set_operators[combinator] - if all and combinator == "union": - combinator_sql += " ALL" - - braces = "{}" - if not self.query.subquery and features.supports_slicing_ordering_in_compound: - braces = "({})" - - sql_parts, args_parts = zip(*parts) - result = [braces.format(part) for part in sql_parts] - - # Join with combinator - sql = (" %s " % combinator_sql).join(result) - params = [] - for part_args in args_parts: - params.extend(part_args) - - # Return in the same format as Django's get_combinator_sql: ([sql], params) - return [sql], params - else: - # No tree fields present, use Django's default implementation - return super().get_combinator_sql(combinator, all) + elif ( + self.query.subquery + and features.supports_slicing_ordering_in_compound + ): + part_sql = "({})".format(part_sql) + + parts.append((part_sql, part_args)) + except Exception: + # Handle empty results (similar to Django's EmptyResultSet handling) + if combinator == "union" or (combinator == "difference" and parts): + continue + raise + + if not parts: + from django.db.models.sql.compiler import EmptyResultSet + raise EmptyResultSet + + # Combine the parts + combinator_sql = self.connection.ops.set_operators[combinator] + if all and combinator == "union": + combinator_sql += " ALL" + + braces = "{}" + if not self.query.subquery and features.supports_slicing_ordering_in_compound: + braces = "({})" + + sql_parts, args_parts = zip(*parts) + result = [braces.format(part) for part in sql_parts] + + # Join the parts with the combinator + final_sql = (" %s " % combinator_sql).join(result) + final_params = [] + for part_args in args_parts: + final_params.extend(part_args) + + # Return in Django's expected format: [sql], params + return [final_sql], final_params def converter(value, expression, connection, context=None):