Skip to content

Commit 725b19a

Browse files
authored
Merge pull request #620 from DagsHub/bug/duplicate-tree-composition
Bug: composing duplicate queries throws an error
2 parents fd0f8da + fb96cee commit 725b19a

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

dagshub/data_engine/model/query.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import enum
33
import logging
4+
import uuid
45
from typing import Optional, Union, Dict
56

67
import pytz
@@ -190,10 +191,26 @@ def compose(self, op: str, other: Optional[Union[str, int, float, "QueryFilterTr
190191
return
191192
composite_tree = Tree()
192193
root_node = composite_tree.create_node(op)
193-
composite_tree.paste(root_node.identifier, self._operand_tree)
194-
composite_tree.paste(root_node.identifier, other._operand_tree)
194+
195+
# Use deep clones to avoid potentially composing two of the same tree together
196+
# and hitting duplicate identifiers
197+
composite_tree.paste(root_node.identifier, self._clone_operand_tree())
198+
composite_tree.paste(root_node.identifier, other._clone_operand_tree())
195199
self._operand_tree = composite_tree
196200

201+
def _clone_operand_tree(self) -> Tree:
202+
"""
203+
Returns a deep copy of the operand tree, also changing all node identifiers.
204+
This allows to compose two of the same tree together.
205+
"""
206+
if self._operand_tree.root is None:
207+
return Tree()
208+
new_tree = Tree(self._operand_tree.subtree(self._operand_tree.root), deep=True)
209+
node_ids = [node.identifier for node in new_tree.all_nodes_itr()]
210+
for node_id in node_ids:
211+
new_tree.update_node(node_id, identifier=str(uuid.uuid4()))
212+
return new_tree
213+
197214
@property
198215
def _column_filter_node(self) -> Node:
199216
return next(self._operand_tree.filter_nodes(lambda n: n.tag == UNFILLED_NODE_TAG), None)

tests/data_engine/test_querying.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,3 +685,40 @@ def test_datetime_is_null(ds):
685685
expected = {"query": {"filter": {"comparator": "IS_NULL", "key": "x", "value": "0", "valueType": "DATETIME"}}}
686686

687687
assert q.serialize_gql_query_input() == expected
688+
689+
690+
def test_duplicate_subquery(ds):
691+
add_int_fields(ds, "x")
692+
add_int_fields(ds, "y")
693+
ds2 = ds["x"] > 5
694+
695+
ds3 = (ds2["y"] < 10) | (ds2["y"] > 20)
696+
697+
expected = {
698+
"or": {
699+
"children": [
700+
{
701+
"and": {
702+
"children": [
703+
{"gt": {"data": {"field": "x", "value": 5}}},
704+
{"lt": {"data": {"field": "y", "value": 10}}},
705+
],
706+
"data": None,
707+
}
708+
},
709+
{
710+
"and": {
711+
"children": [
712+
{"gt": {"data": {"field": "x", "value": 5}}},
713+
{"gt": {"data": {"field": "y", "value": 20}}},
714+
],
715+
"data": None,
716+
}
717+
},
718+
],
719+
"data": None,
720+
}
721+
}
722+
723+
actual = ds3.get_query().filter.tree_to_dict()
724+
assert actual == expected

0 commit comments

Comments
 (0)