Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ Usage
- Create a manager using
``TreeQuerySet.as_manager(with_tree_fields=True)`` if you want to add
tree fields to queries by default.

**Important note about tree fields and object creation:**

When using ``TreeQuerySet.as_manager(with_tree_fields=True)``, tree fields
(``tree_depth``, ``tree_path``, ``tree_ordering``) are automatically available
on instances returned by ``Model.objects.create()``. For example:

.. code-block:: python

class Node(TreeNode):
name = models.CharField(max_length=100)
objects = TreeQuerySet.as_manager(with_tree_fields=True)

# Tree fields are available immediately after create()
root = Node.objects.create(name="Root")
print(root.tree_depth) # 0
print(root.tree_path) # [root.pk]

child = Node.objects.create(name="Child", parent=root)
print(child.tree_depth) # 1

However, ``Model.objects.bulk_create()`` does not provide tree fields on the
returned objects (this is expected behavior for performance reasons).

**Regarding refresh_from_db():**

The ``refresh_from_db()`` method may not restore tree fields for models that
don't set ``base_manager_name = "objects"`` in their Meta class. If you need
to refresh tree fields, use ``obj = Model.objects.get(pk=obj.pk)`` instead.

- Until documentation is more complete I'll have to refer you to the
`test suite
<https://github.com/matthiask/django-tree-queries/blob/main/tests/testapp/test_queries.py>`_
Expand Down
171 changes: 171 additions & 0 deletions tests/testapp/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Generated by Django 4.2.11 on 2025-09-24 11:11

from django.db import migrations, models
import django.db.models.deletion
import uuid


class Migration(migrations.Migration):

initial = True

dependencies = [
]

operations = [
migrations.CreateModel(
name='AlwaysTreeQueryModelCategory',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
],
),
migrations.CreateModel(
name='InheritParentModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.inheritparentmodel', verbose_name='parent')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
name='Model',
fields=[
('custom_id', models.AutoField(primary_key=True, serialize=False)),
('order', models.PositiveIntegerField(default=0)),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.model', verbose_name='parent')),
],
options={
'ordering': ('order',),
},
),
migrations.CreateModel(
name='RelatedOrderModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.relatedordermodel', verbose_name='parent')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
name='InheritChildModel',
fields=[
('inheritparentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritparentmodel')),
],
options={
'abstract': False,
},
bases=('testapp.inheritparentmodel',),
),
migrations.CreateModel(
name='InheritConcreteGrandChildModel',
fields=[
('inheritparentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritparentmodel')),
],
options={
'abstract': False,
},
bases=('testapp.inheritparentmodel',),
),
migrations.CreateModel(
name='OneToOneRelatedOrder',
fields=[
('relatedmodel', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='related', serialize=False, to='testapp.relatedordermodel')),
('order', models.PositiveIntegerField(default=0)),
],
),
migrations.CreateModel(
name='UUIDModel',
fields=[
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.uuidmodel', verbose_name='parent')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
name='UnorderedModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.unorderedmodel', verbose_name='parent')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
name='TreeNodeIsOptional',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('parent', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='testapp.treenodeisoptional')),
],
),
migrations.CreateModel(
name='ReferenceModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('position', models.PositiveIntegerField(default=0)),
('tree_field', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='testapp.model')),
],
options={
'ordering': ('position',),
},
),
migrations.CreateModel(
name='MultiOrderedModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('first_position', models.PositiveIntegerField(default=0)),
('second_position', models.PositiveIntegerField(default=0)),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.multiorderedmodel', verbose_name='parent')),
],
options={
'ordering': ('first_position',),
},
),
migrations.CreateModel(
name='AlwaysTreeQueryModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('category', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='instances', to='testapp.alwaystreequerymodelcategory')),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.alwaystreequerymodel', verbose_name='parent')),
('related', models.ManyToManyField(to='testapp.alwaystreequerymodel')),
],
options={
'base_manager_name': 'objects',
},
),
migrations.CreateModel(
name='InheritGrandChildModel',
fields=[
('inheritchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritchildmodel')),
],
options={
'abstract': False,
},
bases=('testapp.inheritchildmodel',),
),
migrations.CreateModel(
name='StringOrderedModel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.stringorderedmodel', verbose_name='parent')),
],
options={
'ordering': ('name',),
'unique_together': {('name', 'parent')},
},
),
]
Empty file.
58 changes: 58 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,61 @@ def test_tree_fields_optimization(self):

child2_2 = next(obj for obj in results if obj.name == "2-2")
assert child2_2.tree_names == ["root", "2", "2-2"]

def test_create_with_tree_fields(self):
"""Test that tree fields are available on instances returned by create() when with_tree_fields=True"""

# Clear any existing data
AlwaysTreeQueryModel.objects.all().delete()
Model.objects.all().delete()

# Test with AlwaysTreeQueryModel (has with_tree_fields=True by default)
root = AlwaysTreeQueryModel.objects.create(name="Root")

# Tree fields should be available directly after create()
assert hasattr(root, 'tree_depth'), "tree_depth should be available after create() with with_tree_fields=True"
assert hasattr(root, 'tree_path'), "tree_path should be available after create() with with_tree_fields=True"
assert hasattr(root, 'tree_ordering'), "tree_ordering should be available after create() with with_tree_fields=True"

# Verify the values are correct
assert root.tree_depth == 0, f"Root should have tree_depth=0, got {root.tree_depth}"
assert root.tree_path == [root.pk], f"Root tree_path should be [pk], got {root.tree_path}"

# Test with child
child = AlwaysTreeQueryModel.objects.create(name="Child", parent=root)
assert hasattr(child, 'tree_depth'), "tree_depth should be available after create() for child"
assert child.tree_depth == 1, f"Child should have tree_depth=1, got {child.tree_depth}"
assert len(child.tree_path) == 2, f"Child tree_path should have 2 elements, got {child.tree_path}"
assert child.tree_path[0] == root.pk, f"Child tree_path[0] should be parent pk, got {child.tree_path}"
assert child.tree_path[1] == child.pk, f"Child tree_path[1] should be child pk, got {child.tree_path}"

# Test with regular Model (doesn't have with_tree_fields=True by default)
regular_instance = Model.objects.create(name="Regular")
assert not hasattr(regular_instance, 'tree_depth'), "tree_depth should NOT be available when with_tree_fields=False"

# But should be available when queried with tree fields
regular_with_fields = Model.objects.with_tree_fields().get(pk=regular_instance.pk)
assert hasattr(regular_with_fields, 'tree_depth'), "tree_depth should be available when explicitly requested"

def test_bulk_create_tree_fields(self):
"""Test that bulk_create behaves correctly with tree fields"""

AlwaysTreeQueryModel.objects.all().delete()

# bulk_create should work but won't have tree fields on returned objects
# (this is expected and documented behavior)
objs = AlwaysTreeQueryModel.objects.bulk_create([
AlwaysTreeQueryModel(name="Bulk1"),
AlwaysTreeQueryModel(name="Bulk2"),
])

# The returned objects won't have tree fields (expected)
for obj in objs:
assert not hasattr(obj, 'tree_depth'), "bulk_create objects should not have tree fields"

# But when queried, they should have tree fields
queried_objs = list(AlwaysTreeQueryModel.objects.filter(name__startswith="Bulk"))
assert len(queried_objs) == 2
for obj in queried_objs:
assert hasattr(obj, 'tree_depth'), "queried objects should have tree fields"
assert obj.tree_depth == 0, "bulk created root nodes should have tree_depth=0"
24 changes: 24 additions & 0 deletions tree_queries/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,30 @@ def get_queryset(self):
queryset = super().get_queryset()
return queryset.with_tree_fields() if self._with_tree_fields else queryset

def create(self, **kwargs):
"""
Create a new object with the given kwargs, saving it to the database
and returning the created object.

If the manager has with_tree_fields=True, the returned object will
have tree fields populated by re-querying from the database.
"""
obj = super().create(**kwargs)

# If tree fields are enabled, re-fetch the object to get tree annotations
if self._with_tree_fields:
obj = self.get(pk=obj.pk)

return obj

def bulk_create(self, objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None):
"""
Create multiple objects efficiently. Note that tree fields will NOT be
available on the returned objects even when with_tree_fields=True,
since bulk operations don't return individual annotated instances.
"""
return super().bulk_create(objs, batch_size, ignore_conflicts, update_conflicts, update_fields, unique_fields)


class TreeQuerySet(models.QuerySet):
def with_tree_fields(self, tree_fields=True): # noqa: FBT002
Expand Down