From 220ada9b78d81806b4e5b7c9ccdd82f3343303fd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:06:04 +0000 Subject: [PATCH 1/3] Initial plan From ed4977412d12251bd8a5a0be09afc06c797d2ac9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:14:00 +0000 Subject: [PATCH 2/3] Implement add_related_count method for TreeQuerySet Co-authored-by: matthiask <2627+matthiask@users.noreply.github.com> --- tests/testapp/models.py | 20 ++++++++ tests/testapp/test_queries.py | 91 +++++++++++++++++++++++++++++++++++ tree_queries/query.py | 89 ++++++++++++++++++++++++++++++++++ 3 files changed, 200 insertions(+) diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 7cbfe3a..639f210 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -139,3 +139,23 @@ class OneToOneRelatedOrder(models.Model): def __str__(self): return "" + + +# Models for testing add_related_count functionality +class Region(TreeNode): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class Site(models.Model): + name = models.CharField(max_length=100) + region = models.ForeignKey( + Region, + on_delete=models.CASCADE, + related_name="sites", + ) + + def __str__(self): + return self.name diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 3f8b9cd..3fce1d8 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -19,7 +19,9 @@ MultiOrderedModel, OneToOneRelatedOrder, ReferenceModel, + Region, RelatedOrderModel, + Site, StringOrderedModel, TreeNodeIsOptional, UnorderedModel, @@ -1055,3 +1057,92 @@ def test_tree_fields_optimization(self): child2_2 = next(obj for obj in results if obj.name == "2-2") assert child2_2.tree_names == ["root", "2", "2-2"] + + def test_add_related_count_non_cumulative(self): + """Test add_related_count with cumulative=False""" + # Create regions + country = Region.objects.create(name="USA") + state1 = Region.objects.create(name="California", parent=country) + state2 = Region.objects.create(name="Texas", parent=country) + city1 = Region.objects.create(name="San Francisco", parent=state1) + city2 = Region.objects.create(name="Los Angeles", parent=state1) + + # Create sites - some directly associated with each region + Site.objects.create(name="Site 1", region=city1) + Site.objects.create(name="Site 2", region=city1) # 2 sites for city1 + Site.objects.create(name="Site 3", region=city2) # 1 site for city2 + Site.objects.create(name="Site 4", region=state1) # 1 site directly for state1 + Site.objects.create(name="Site 5", region=country) # 1 site directly for country + + # Test non-cumulative counting + result = Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'site_count', + cumulative=False + ) + + regions_by_name = {r.name: r for r in result} + + # Check direct counts only + assert regions_by_name["San Francisco"].site_count == 2 + assert regions_by_name["Los Angeles"].site_count == 1 + assert regions_by_name["California"].site_count == 1 + assert regions_by_name["Texas"].site_count == 0 + assert regions_by_name["USA"].site_count == 1 + + def test_add_related_count_cumulative(self): + """Test add_related_count with cumulative=True""" + # Create regions + country = Region.objects.create(name="USA") + state1 = Region.objects.create(name="California", parent=country) + state2 = Region.objects.create(name="Texas", parent=country) + city1 = Region.objects.create(name="San Francisco", parent=state1) + city2 = Region.objects.create(name="Los Angeles", parent=state1) + city3 = Region.objects.create(name="Houston", parent=state2) + + # Create sites + Site.objects.create(name="Site 1", region=city1) + Site.objects.create(name="Site 2", region=city1) # 2 sites for city1 + Site.objects.create(name="Site 3", region=city2) # 1 site for city2 + Site.objects.create(name="Site 4", region=city3) # 1 site for city3 + Site.objects.create(name="Site 5", region=state1) # 1 site directly for state1 + Site.objects.create(name="Site 6", region=country) # 1 site directly for country + + # Test cumulative counting + result = Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'site_count', + cumulative=True + ) + + regions_by_name = {r.name: r for r in result} + + # Check cumulative counts + assert regions_by_name["San Francisco"].site_count == 2 # Just its own + assert regions_by_name["Los Angeles"].site_count == 1 # Just its own + assert regions_by_name["Houston"].site_count == 1 # Just its own + assert regions_by_name["California"].site_count == 4 # 1 direct + 2 from SF + 1 from LA + assert regions_by_name["Texas"].site_count == 1 # 1 from Houston + assert regions_by_name["USA"].site_count == 6 # All sites + + def test_add_related_count_empty_tree(self): + """Test add_related_count with no related objects""" + # Create regions but no sites + country = Region.objects.create(name="USA") + state = Region.objects.create(name="California", parent=country) + + # Test with no related objects + result = Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'site_count', + cumulative=True + ) + + for region in result: + assert region.site_count == 0 or region.site_count is None diff --git a/tree_queries/query.py b/tree_queries/query.py index d1b36dc..61a6915 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -1,4 +1,5 @@ from django.db import connections, models +from django.db.models import Count, OuterRef, Subquery from django.db.models.sql.query import Query from tree_queries.compiler import SEPARATOR, TreeQuery @@ -139,3 +140,91 @@ def descendants(self, of, *, include_self=False): if not include_self: return queryset.exclude(pk=pk(of)) return queryset + + def add_related_count( + self, + queryset, + rel_model, + rel_field, + count_attr, + cumulative=False, + ): + """ + Annotates each instance in the queryset with a count of related objects. + + This is a replacement for django-mptt's add_related_count method, adapted + to work with django-tree-queries' CTE-based approach. + + Args: + queryset: The queryset to annotate + rel_model: The related model to count instances of + rel_field: Field name on rel_model that points to the tree model + count_attr: Name of the annotation to add to each instance + cumulative: If True, count includes related objects from descendants + + Returns: + An annotated queryset + + Example: + Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'site_count', + cumulative=True + ) + """ + # If not cumulative, use simple annotation based on direct relationships + if not cumulative: + # Get the related field to find the reverse relationship name + rel_field_obj = rel_model._meta.get_field(rel_field) + if hasattr(rel_field_obj, 'remote_field') and rel_field_obj.remote_field: + related_name = rel_field_obj.remote_field.related_name + if related_name: + # Use the explicitly defined related_name + return queryset.annotate(**{ + count_attr: Count(related_name, distinct=True) + }) + + # Fall back to generic reverse lookup + reverse_name = f"{rel_model._meta.model_name}_set" + return queryset.annotate(**{ + count_attr: Count(reverse_name, distinct=True) + }) + + # For cumulative counts, we need to count related objects for each node + # and all its descendants using tree_path + base_queryset = queryset.with_tree_fields() + connection = connections[queryset.db] + + if connection.vendor == "postgresql": + # PostgreSQL: Use array operations with tree_path + # Create a subquery that gets all descendants of each node (including self) + # and counts their related objects + descendants_subquery = self.model.objects.with_tree_fields().extra( + where=["%s = ANY(__tree.tree_path)"], + params=[OuterRef('pk')] + ).values('pk') + + count_subquery = Subquery( + rel_model.objects.filter( + **{f"{rel_field}__in": descendants_subquery} + ).aggregate(total=Count('pk')).values('total')[:1] + ) + else: + # Other databases: Use string operations on tree_path + # Find nodes whose tree_path contains the current node's pk + descendants_subquery = self.model.objects.with_tree_fields().extra( + where=[ + f'instr(__tree.tree_path, "{SEPARATOR}" || %s || "{SEPARATOR}") <> 0' + ], + params=[OuterRef('pk')] + ).values('pk') + + count_subquery = Subquery( + rel_model.objects.filter( + **{f"{rel_field}__in": descendants_subquery} + ).aggregate(total=Count('pk')).values('total')[:1] + ) + + return base_queryset.annotate(**{count_attr: count_subquery}) From 9d43a63e37ec84c40f5aadf745369e84f9352f37 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:16:30 +0000 Subject: [PATCH 3/3] Add comprehensive documentation for add_related_count method Co-authored-by: matthiask <2627+matthiask@users.noreply.github.com> --- README.rst | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/README.rst b/README.rst index 65a5594..f33d441 100644 --- a/README.rst +++ b/README.rst @@ -34,6 +34,9 @@ Features and limitations methods for ordering siblings and filtering ancestors and descendants. Other features may be useful, but will not be added to the package just because it's possible to do so. +- Includes ``add_related_count()`` method for counting related objects with + support for cumulative counting across tree hierarchies (replacement for + django-mptt's method of the same name). - Little code, and relatively simple when compared to other tree management solutions for Django. No redundant values so the only way to end up with corrupt data is by introducing a loop in the tree @@ -218,6 +221,69 @@ before the recursive CTE processes relationships, dramatically improving perform for large datasets compared to using regular ``filter()`` after ``with_tree_fields()``. Best used for selecting complete trees or tree sections rather than scattered nodes. + +Counting related objects +------------------------ + +django-tree-queries provides ``add_related_count()`` as a replacement for +django-mptt's method of the same name. This method annotates tree nodes with +counts of related objects, with support for cumulative counting that includes +counts from descendant nodes. + +.. code-block:: python + + # Example models + class Region(TreeNode): + name = models.CharField(max_length=100) + + class Site(models.Model): + name = models.CharField(max_length=100) + region = models.ForeignKey(Region, on_delete=models.CASCADE, related_name="sites") + + # Count sites directly assigned to each region (non-cumulative) + regions = Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'site_count', + cumulative=False + ) + + # Each region will have a site_count attribute with direct counts only + for region in regions: + print(f"{region.name}: {region.site_count} direct sites") + + # Count sites assigned to each region and all its descendants (cumulative) + regions_cumulative = Region.objects.add_related_count( + Region.objects.all(), + Site, + 'region', + 'total_sites', + cumulative=True + ) + + # Each region will have a total_sites attribute with cumulative counts + for region in regions_cumulative: + print(f"{region.name}: {region.total_sites} total sites") + +The method signature is: + +.. code-block:: python + + add_related_count(queryset, rel_model, rel_field, count_attr, cumulative=False) + +Parameters: + +- ``queryset``: The queryset to annotate (typically ``self``) +- ``rel_model``: The related model to count instances of +- ``rel_field``: Field name on ``rel_model`` that points to the tree model +- ``count_attr``: Name of the annotation to add to each instance +- ``cumulative``: If ``True``, counts include related objects from descendants + +The implementation automatically detects the database backend and uses optimized +queries for PostgreSQL (with array operations) while maintaining compatibility +with SQLite, MySQL, and MariaDB. + Note that the tree queryset doesn't support all types of queries Django supports. For example, updating all descendants directly isn't supported. The reason for that is that the recursive CTE isn't added to the UPDATE query