diff --git a/README.rst b/README.rst index 65a5594..422da99 100644 --- a/README.rst +++ b/README.rst @@ -74,6 +74,36 @@ Usage - Create a manager using ``TreeQuerySet.as_manager(with_tree_fields=True)`` if you want to add tree fields to queries by default. + +**Important note about tree fields and object creation:** + +When using ``TreeQuerySet.as_manager(with_tree_fields=True)``, tree fields +(``tree_depth``, ``tree_path``, ``tree_ordering``) are automatically available +on instances returned by ``Model.objects.create()``. For example: + +.. code-block:: python + + class Node(TreeNode): + name = models.CharField(max_length=100) + objects = TreeQuerySet.as_manager(with_tree_fields=True) + + # Tree fields are available immediately after create() + root = Node.objects.create(name="Root") + print(root.tree_depth) # 0 + print(root.tree_path) # [root.pk] + + child = Node.objects.create(name="Child", parent=root) + print(child.tree_depth) # 1 + +However, ``Model.objects.bulk_create()`` does not provide tree fields on the +returned objects (this is expected behavior for performance reasons). + +**Regarding refresh_from_db():** + +The ``refresh_from_db()`` method may not restore tree fields for models that +don't set ``base_manager_name = "objects"`` in their Meta class. If you need +to refresh tree fields, use ``obj = Model.objects.get(pk=obj.pk)`` instead. + - Until documentation is more complete I'll have to refer you to the `test suite `_ diff --git a/tests/testapp/migrations/0001_initial.py b/tests/testapp/migrations/0001_initial.py new file mode 100644 index 0000000..942e6d5 --- /dev/null +++ b/tests/testapp/migrations/0001_initial.py @@ -0,0 +1,171 @@ +# Generated by Django 4.2.11 on 2025-09-24 11:11 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='AlwaysTreeQueryModelCategory', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + migrations.CreateModel( + name='InheritParentModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.inheritparentmodel', verbose_name='parent')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='Model', + fields=[ + ('custom_id', models.AutoField(primary_key=True, serialize=False)), + ('order', models.PositiveIntegerField(default=0)), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.model', verbose_name='parent')), + ], + options={ + 'ordering': ('order',), + }, + ), + migrations.CreateModel( + name='RelatedOrderModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.relatedordermodel', verbose_name='parent')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='InheritChildModel', + fields=[ + ('inheritparentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritparentmodel')), + ], + options={ + 'abstract': False, + }, + bases=('testapp.inheritparentmodel',), + ), + migrations.CreateModel( + name='InheritConcreteGrandChildModel', + fields=[ + ('inheritparentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritparentmodel')), + ], + options={ + 'abstract': False, + }, + bases=('testapp.inheritparentmodel',), + ), + migrations.CreateModel( + name='OneToOneRelatedOrder', + fields=[ + ('relatedmodel', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='related', serialize=False, to='testapp.relatedordermodel')), + ('order', models.PositiveIntegerField(default=0)), + ], + ), + migrations.CreateModel( + name='UUIDModel', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.uuidmodel', verbose_name='parent')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='UnorderedModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.unorderedmodel', verbose_name='parent')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='TreeNodeIsOptional', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('parent', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='testapp.treenodeisoptional')), + ], + ), + migrations.CreateModel( + name='ReferenceModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('position', models.PositiveIntegerField(default=0)), + ('tree_field', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='testapp.model')), + ], + options={ + 'ordering': ('position',), + }, + ), + migrations.CreateModel( + name='MultiOrderedModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('first_position', models.PositiveIntegerField(default=0)), + ('second_position', models.PositiveIntegerField(default=0)), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.multiorderedmodel', verbose_name='parent')), + ], + options={ + 'ordering': ('first_position',), + }, + ), + migrations.CreateModel( + name='AlwaysTreeQueryModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('category', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='instances', to='testapp.alwaystreequerymodelcategory')), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.alwaystreequerymodel', verbose_name='parent')), + ('related', models.ManyToManyField(to='testapp.alwaystreequerymodel')), + ], + options={ + 'base_manager_name': 'objects', + }, + ), + migrations.CreateModel( + name='InheritGrandChildModel', + fields=[ + ('inheritchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='testapp.inheritchildmodel')), + ], + options={ + 'abstract': False, + }, + bases=('testapp.inheritchildmodel',), + ), + migrations.CreateModel( + name='StringOrderedModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='testapp.stringorderedmodel', verbose_name='parent')), + ], + options={ + 'ordering': ('name',), + 'unique_together': {('name', 'parent')}, + }, + ), + ] diff --git a/tests/testapp/migrations/__init__.py b/tests/testapp/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 3f8b9cd..1cc10c4 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -1055,3 +1055,61 @@ def test_tree_fields_optimization(self): child2_2 = next(obj for obj in results if obj.name == "2-2") assert child2_2.tree_names == ["root", "2", "2-2"] + + def test_create_with_tree_fields(self): + """Test that tree fields are available on instances returned by create() when with_tree_fields=True""" + + # Clear any existing data + AlwaysTreeQueryModel.objects.all().delete() + Model.objects.all().delete() + + # Test with AlwaysTreeQueryModel (has with_tree_fields=True by default) + root = AlwaysTreeQueryModel.objects.create(name="Root") + + # Tree fields should be available directly after create() + assert hasattr(root, 'tree_depth'), "tree_depth should be available after create() with with_tree_fields=True" + assert hasattr(root, 'tree_path'), "tree_path should be available after create() with with_tree_fields=True" + assert hasattr(root, 'tree_ordering'), "tree_ordering should be available after create() with with_tree_fields=True" + + # Verify the values are correct + assert root.tree_depth == 0, f"Root should have tree_depth=0, got {root.tree_depth}" + assert root.tree_path == [root.pk], f"Root tree_path should be [pk], got {root.tree_path}" + + # Test with child + child = AlwaysTreeQueryModel.objects.create(name="Child", parent=root) + assert hasattr(child, 'tree_depth'), "tree_depth should be available after create() for child" + assert child.tree_depth == 1, f"Child should have tree_depth=1, got {child.tree_depth}" + assert len(child.tree_path) == 2, f"Child tree_path should have 2 elements, got {child.tree_path}" + assert child.tree_path[0] == root.pk, f"Child tree_path[0] should be parent pk, got {child.tree_path}" + assert child.tree_path[1] == child.pk, f"Child tree_path[1] should be child pk, got {child.tree_path}" + + # Test with regular Model (doesn't have with_tree_fields=True by default) + regular_instance = Model.objects.create(name="Regular") + assert not hasattr(regular_instance, 'tree_depth'), "tree_depth should NOT be available when with_tree_fields=False" + + # But should be available when queried with tree fields + regular_with_fields = Model.objects.with_tree_fields().get(pk=regular_instance.pk) + assert hasattr(regular_with_fields, 'tree_depth'), "tree_depth should be available when explicitly requested" + + def test_bulk_create_tree_fields(self): + """Test that bulk_create behaves correctly with tree fields""" + + AlwaysTreeQueryModel.objects.all().delete() + + # bulk_create should work but won't have tree fields on returned objects + # (this is expected and documented behavior) + objs = AlwaysTreeQueryModel.objects.bulk_create([ + AlwaysTreeQueryModel(name="Bulk1"), + AlwaysTreeQueryModel(name="Bulk2"), + ]) + + # The returned objects won't have tree fields (expected) + for obj in objs: + assert not hasattr(obj, 'tree_depth'), "bulk_create objects should not have tree fields" + + # But when queried, they should have tree fields + queried_objs = list(AlwaysTreeQueryModel.objects.filter(name__startswith="Bulk")) + assert len(queried_objs) == 2 + for obj in queried_objs: + assert hasattr(obj, 'tree_depth'), "queried objects should have tree fields" + assert obj.tree_depth == 0, "bulk created root nodes should have tree_depth=0" diff --git a/tree_queries/query.py b/tree_queries/query.py index d1b36dc..1b410ce 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -17,6 +17,30 @@ def get_queryset(self): queryset = super().get_queryset() return queryset.with_tree_fields() if self._with_tree_fields else queryset + def create(self, **kwargs): + """ + Create a new object with the given kwargs, saving it to the database + and returning the created object. + + If the manager has with_tree_fields=True, the returned object will + have tree fields populated by re-querying from the database. + """ + obj = super().create(**kwargs) + + # If tree fields are enabled, re-fetch the object to get tree annotations + if self._with_tree_fields: + obj = self.get(pk=obj.pk) + + return obj + + def bulk_create(self, objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None): + """ + Create multiple objects efficiently. Note that tree fields will NOT be + available on the returned objects even when with_tree_fields=True, + since bulk operations don't return individual annotated instances. + """ + return super().bulk_create(objs, batch_size, ignore_conflicts, update_conflicts, update_fields, unique_fields) + class TreeQuerySet(models.QuerySet): def with_tree_fields(self, tree_fields=True): # noqa: FBT002