Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Features and limitations
methods for ordering siblings and filtering ancestors and descendants. Other
features may be useful, but will not be added to the package just because
it's possible to do so.
- Includes ``add_related_count()`` method for counting related objects with
support for cumulative counting across tree hierarchies (replacement for
django-mptt's method of the same name).
- Little code, and relatively simple when compared to other tree
management solutions for Django. No redundant values so the only way
to end up with corrupt data is by introducing a loop in the tree
Expand Down Expand Up @@ -218,6 +221,69 @@ before the recursive CTE processes relationships, dramatically improving perform
for large datasets compared to using regular ``filter()`` after ``with_tree_fields()``.
Best used for selecting complete trees or tree sections rather than scattered nodes.


Counting related objects
------------------------

django-tree-queries provides ``add_related_count()`` as a replacement for
django-mptt's method of the same name. This method annotates tree nodes with
counts of related objects, with support for cumulative counting that includes
counts from descendant nodes.

.. code-block:: python

# Example models
class Region(TreeNode):
name = models.CharField(max_length=100)

class Site(models.Model):
name = models.CharField(max_length=100)
region = models.ForeignKey(Region, on_delete=models.CASCADE, related_name="sites")

# Count sites directly assigned to each region (non-cumulative)
regions = Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'site_count',
cumulative=False
)

# Each region will have a site_count attribute with direct counts only
for region in regions:
print(f"{region.name}: {region.site_count} direct sites")

# Count sites assigned to each region and all its descendants (cumulative)
regions_cumulative = Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'total_sites',
cumulative=True
)

# Each region will have a total_sites attribute with cumulative counts
for region in regions_cumulative:
print(f"{region.name}: {region.total_sites} total sites")

The method signature is:

.. code-block:: python

add_related_count(queryset, rel_model, rel_field, count_attr, cumulative=False)

Parameters:

- ``queryset``: The queryset to annotate (typically ``self``)
- ``rel_model``: The related model to count instances of
- ``rel_field``: Field name on ``rel_model`` that points to the tree model
- ``count_attr``: Name of the annotation to add to each instance
- ``cumulative``: If ``True``, counts include related objects from descendants

The implementation automatically detects the database backend and uses optimized
queries for PostgreSQL (with array operations) while maintaining compatibility
with SQLite, MySQL, and MariaDB.

Note that the tree queryset doesn't support all types of queries Django
supports. For example, updating all descendants directly isn't supported. The
reason for that is that the recursive CTE isn't added to the UPDATE query
Expand Down
20 changes: 20 additions & 0 deletions tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,23 @@ class OneToOneRelatedOrder(models.Model):

def __str__(self):
return ""


# Models for testing add_related_count functionality
class Region(TreeNode):
name = models.CharField(max_length=100)

def __str__(self):
return self.name


class Site(models.Model):
name = models.CharField(max_length=100)
region = models.ForeignKey(
Region,
on_delete=models.CASCADE,
related_name="sites",
)

def __str__(self):
return self.name
91 changes: 91 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
MultiOrderedModel,
OneToOneRelatedOrder,
ReferenceModel,
Region,
RelatedOrderModel,
Site,
StringOrderedModel,
TreeNodeIsOptional,
UnorderedModel,
Expand Down Expand Up @@ -1055,3 +1057,92 @@ def test_tree_fields_optimization(self):

child2_2 = next(obj for obj in results if obj.name == "2-2")
assert child2_2.tree_names == ["root", "2", "2-2"]

def test_add_related_count_non_cumulative(self):
"""Test add_related_count with cumulative=False"""
# Create regions
country = Region.objects.create(name="USA")
state1 = Region.objects.create(name="California", parent=country)
state2 = Region.objects.create(name="Texas", parent=country)
city1 = Region.objects.create(name="San Francisco", parent=state1)
city2 = Region.objects.create(name="Los Angeles", parent=state1)

# Create sites - some directly associated with each region
Site.objects.create(name="Site 1", region=city1)
Site.objects.create(name="Site 2", region=city1) # 2 sites for city1
Site.objects.create(name="Site 3", region=city2) # 1 site for city2
Site.objects.create(name="Site 4", region=state1) # 1 site directly for state1
Site.objects.create(name="Site 5", region=country) # 1 site directly for country

# Test non-cumulative counting
result = Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'site_count',
cumulative=False
)

regions_by_name = {r.name: r for r in result}

# Check direct counts only
assert regions_by_name["San Francisco"].site_count == 2
assert regions_by_name["Los Angeles"].site_count == 1
assert regions_by_name["California"].site_count == 1
assert regions_by_name["Texas"].site_count == 0
assert regions_by_name["USA"].site_count == 1

def test_add_related_count_cumulative(self):
"""Test add_related_count with cumulative=True"""
# Create regions
country = Region.objects.create(name="USA")
state1 = Region.objects.create(name="California", parent=country)
state2 = Region.objects.create(name="Texas", parent=country)
city1 = Region.objects.create(name="San Francisco", parent=state1)
city2 = Region.objects.create(name="Los Angeles", parent=state1)
city3 = Region.objects.create(name="Houston", parent=state2)

# Create sites
Site.objects.create(name="Site 1", region=city1)
Site.objects.create(name="Site 2", region=city1) # 2 sites for city1
Site.objects.create(name="Site 3", region=city2) # 1 site for city2
Site.objects.create(name="Site 4", region=city3) # 1 site for city3
Site.objects.create(name="Site 5", region=state1) # 1 site directly for state1
Site.objects.create(name="Site 6", region=country) # 1 site directly for country

# Test cumulative counting
result = Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'site_count',
cumulative=True
)

regions_by_name = {r.name: r for r in result}

# Check cumulative counts
assert regions_by_name["San Francisco"].site_count == 2 # Just its own
assert regions_by_name["Los Angeles"].site_count == 1 # Just its own
assert regions_by_name["Houston"].site_count == 1 # Just its own
assert regions_by_name["California"].site_count == 4 # 1 direct + 2 from SF + 1 from LA
assert regions_by_name["Texas"].site_count == 1 # 1 from Houston
assert regions_by_name["USA"].site_count == 6 # All sites

def test_add_related_count_empty_tree(self):
"""Test add_related_count with no related objects"""
# Create regions but no sites
country = Region.objects.create(name="USA")
state = Region.objects.create(name="California", parent=country)

# Test with no related objects
result = Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'site_count',
cumulative=True
)

for region in result:
assert region.site_count == 0 or region.site_count is None
89 changes: 89 additions & 0 deletions tree_queries/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.db import connections, models
from django.db.models import Count, OuterRef, Subquery
from django.db.models.sql.query import Query

from tree_queries.compiler import SEPARATOR, TreeQuery
Expand Down Expand Up @@ -139,3 +140,91 @@ def descendants(self, of, *, include_self=False):
if not include_self:
return queryset.exclude(pk=pk(of))
return queryset

def add_related_count(
self,
queryset,
rel_model,
rel_field,
count_attr,
cumulative=False,
):
"""
Annotates each instance in the queryset with a count of related objects.

This is a replacement for django-mptt's add_related_count method, adapted
to work with django-tree-queries' CTE-based approach.

Args:
queryset: The queryset to annotate
rel_model: The related model to count instances of
rel_field: Field name on rel_model that points to the tree model
count_attr: Name of the annotation to add to each instance
cumulative: If True, count includes related objects from descendants

Returns:
An annotated queryset

Example:
Region.objects.add_related_count(
Region.objects.all(),
Site,
'region',
'site_count',
cumulative=True
)
"""
# If not cumulative, use simple annotation based on direct relationships
if not cumulative:
# Get the related field to find the reverse relationship name
rel_field_obj = rel_model._meta.get_field(rel_field)
if hasattr(rel_field_obj, 'remote_field') and rel_field_obj.remote_field:
related_name = rel_field_obj.remote_field.related_name
if related_name:
# Use the explicitly defined related_name
return queryset.annotate(**{
count_attr: Count(related_name, distinct=True)
})

# Fall back to generic reverse lookup
reverse_name = f"{rel_model._meta.model_name}_set"
return queryset.annotate(**{
count_attr: Count(reverse_name, distinct=True)
})

# For cumulative counts, we need to count related objects for each node
# and all its descendants using tree_path
base_queryset = queryset.with_tree_fields()
connection = connections[queryset.db]

if connection.vendor == "postgresql":
# PostgreSQL: Use array operations with tree_path
# Create a subquery that gets all descendants of each node (including self)
# and counts their related objects
descendants_subquery = self.model.objects.with_tree_fields().extra(
where=["%s = ANY(__tree.tree_path)"],
params=[OuterRef('pk')]
).values('pk')

count_subquery = Subquery(
rel_model.objects.filter(
**{f"{rel_field}__in": descendants_subquery}
).aggregate(total=Count('pk')).values('total')[:1]
)
else:
# Other databases: Use string operations on tree_path
# Find nodes whose tree_path contains the current node's pk
descendants_subquery = self.model.objects.with_tree_fields().extra(
where=[
f'instr(__tree.tree_path, "{SEPARATOR}" || %s || "{SEPARATOR}") <> 0'
],
params=[OuterRef('pk')]
).values('pk')

count_subquery = Subquery(
rel_model.objects.filter(
**{f"{rel_field}__in": descendants_subquery}
).aggregate(total=Count('pk')).values('total')[:1]
)

return base_queryset.annotate(**{count_attr: count_subquery})