Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,128 @@ def test_values_list_ancestors(self):
Model.objects.ancestors(tree.child2_1).values_list("parent", flat=True)
) == [tree.root.parent_id, tree.child2.parent_id]

def test_values_with_tree_fields(self):
"""Test that tree fields can be included in values() calls"""
tree = self.create_tree()

# Test values() with tree_depth
depth_values = list(Model.objects.with_tree_fields().values("name", "tree_depth"))
expected_depths = [
{"name": "root", "tree_depth": 0},
{"name": "1", "tree_depth": 1},
{"name": "1-1", "tree_depth": 2},
{"name": "2", "tree_depth": 1},
{"name": "2-1", "tree_depth": 2},
{"name": "2-2", "tree_depth": 2},
]
assert depth_values == expected_depths

# Test values() with only tree fields
tree_only = list(Model.objects.with_tree_fields().values("tree_depth"))
assert tree_only == [
{"tree_depth": 0},
{"tree_depth": 1},
{"tree_depth": 2},
{"tree_depth": 1},
{"tree_depth": 2},
{"tree_depth": 2},
]

# Test values() with tree_path
path_values = list(Model.objects.with_tree_fields().values("name", "tree_path"))
root_path = next(item for item in path_values if item["name"] == "root")
child2_2_path = next(item for item in path_values if item["name"] == "2-2")

assert root_path["tree_path"] == [tree.root.pk]
assert child2_2_path["tree_path"] == [tree.root.pk, tree.child2.pk, tree.child2_2.pk]

def test_values_rawsql_workaround_still_works(self):
"""Test that the old RawSQL workaround still works alongside the new functionality"""
tree = self.create_tree()

# Test the old RawSQL workaround mentioned in the issue
rawsql_values = list(
Model.objects.with_tree_fields().values(
"name",
tree_depth=RawSQL("tree_depth", ()),
tree_path=RawSQL("tree_path", ()),
)
)

# Find specific nodes to test
root_item = next(item for item in rawsql_values if item["name"] == "root")
child2_2_item = next(item for item in rawsql_values if item["name"] == "2-2")

# Verify the RawSQL expressions work
assert root_item["tree_depth"] == 0
assert child2_2_item["tree_depth"] == 2

# tree_path format depends on database backend, so just check it's not None
assert root_item["tree_path"] is not None
assert child2_2_item["tree_path"] is not None

def test_values_list_with_tree_fields(self):
"""Test that tree fields work with values_list() calls too"""
tree = self.create_tree()

# Test values_list() with tree_depth
depth_list = list(Model.objects.with_tree_fields().values_list("tree_depth", flat=True))
assert depth_list == [0, 1, 2, 1, 2, 2]

# Test values_list() with multiple fields including tree fields
name_depth_list = list(Model.objects.with_tree_fields().values_list("name", "tree_depth"))
expected_name_depth = [
("root", 0),
("1", 1),
("1-1", 2),
("2", 1),
("2-1", 2),
("2-2", 2),
]
assert name_depth_list == expected_name_depth

def test_values_with_custom_tree_fields(self):
"""Test that custom tree fields from tree_fields() work with values()"""
tree = self.create_tree()

# Test values() with custom tree fields
custom_values = list(
Model.objects.tree_fields(tree_names="name").values("name", "tree_names")
)

# Find specific nodes to test
root_item = next(item for item in custom_values if item["name"] == "root")
child2_2_item = next(item for item in custom_values if item["name"] == "2-2")

# Verify custom tree fields work
assert root_item["tree_names"] == ["root"]
assert child2_2_item["tree_names"] == ["root", "2", "2-2"]

def test_values_no_args_excludes_tree_fields(self):
"""Test that values() with no arguments excludes tree fields (maintains current behavior)"""
tree = self.create_tree()

# values() with no arguments should NOT include tree fields, even when tree fields are enabled
# This maintains backward compatibility with existing behavior
all_values = list(Model.objects.with_tree_fields().values())

# Should have at least one record
assert len(all_values) > 0

# Check the first record has only model fields, not tree fields
first_record = all_values[0]

# Should have model fields
assert "custom_id" in first_record
assert "name" in first_record
assert "order" in first_record
assert "parent_id" in first_record

# Should NOT have tree fields
assert "tree_depth" not in first_record
assert "tree_path" not in first_record
assert "tree_ordering" not in first_record

def test_loops(self):
tree = self.create_tree()
tree.root.parent_id = tree.child1.pk
Expand Down
18 changes: 15 additions & 3 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,22 @@ def as_sql(self, *args, **kwargs):
}
# Add custom tree fields for both simple and complex CTEs
select.update({name: f"__tree.{name}" for name in tree_fields})

# Determine which tree fields to include in the select clause
if skip_tree_fields:
# Skip tree fields for summary queries
select = {}
elif self.query.values_select is not None:
# For values() queries (including values() with no args),
# only include tree fields that were specifically requested
requested_fields = set(self.query.values_select)
available_tree_fields = {"tree_depth", "tree_path", "tree_ordering"} | set(tree_fields.keys())
requested_tree_fields = requested_fields & available_tree_fields
select = {name: expr for name, expr in select.items() if name in requested_tree_fields}
# else: keep all tree fields for normal queries (select stays as-is)

self.query.add_extra(
# Do not add extra fields to the select statement when it is a
# summary query or when using .values() or .values_list()
select={} if skip_tree_fields or self.query.values_select else select,
select=select,
select_params=None,
where=["__tree.tree_pk = {db_table}.{pk}".format(**tree_params)],
params=None,
Expand Down
Loading