22from typing import List , Set , Optional
33from abc import ABC
44
5- from .node_type import NodeType
5+ from .enums import NodeType , JoinType , SortOrder
66
77# ============================================================================
88# Base Node Structure
@@ -13,6 +13,31 @@ class Node(ABC):
1313 def __init__ (self , type : NodeType , children : Optional [Set ['Node' ]| List ['Node' ]] = None ):
1414 self .type = type
1515 self .children = children if children is not None else set ()
16+
17+ def __eq__ (self , other ):
18+ if not isinstance (other , Node ):
19+ return False
20+ if self .type != other .type :
21+ return False
22+ if len (self .children ) != len (other .children ):
23+ return False
24+ # Compare children
25+ if isinstance (self .children , set ) and isinstance (other .children , set ):
26+ return self .children == other .children
27+ elif isinstance (self .children , list ) and isinstance (other .children , list ):
28+ return self .children == other .children
29+ else :
30+ return False
31+
32+ def __hash__ (self ):
33+ # Make nodes hashable by using their type and a hash of their children
34+ if isinstance (self .children , set ):
35+ # For sets, create a deterministic hash by sorting children by their string representation
36+ children_hash = hash (tuple (sorted (self .children , key = lambda x : str (x ))))
37+ else :
38+ # For lists, just hash the tuple directly
39+ children_hash = hash (tuple (self .children ))
40+ return hash ((self .type , children_hash ))
1641
1742
1843# ============================================================================
@@ -25,6 +50,16 @@ def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
2550 super ().__init__ (NodeType .TABLE , ** kwargs )
2651 self .name = _name
2752 self .alias = _alias
53+
54+ def __eq__ (self , other ):
55+ if not isinstance (other , TableNode ):
56+ return False
57+ return (super ().__eq__ (other ) and
58+ self .name == other .name and
59+ self .alias == other .alias )
60+
61+ def __hash__ (self ):
62+ return hash ((super ().__hash__ (), self .name , self .alias ))
2863
2964
3065# TODO - including query structure arguments (similar to QueryNode) in constructor.
@@ -43,6 +78,17 @@ def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Opti
4378 self .alias = _alias
4479 self .parent_alias = _parent_alias
4580 self .parent = _parent
81+
82+ def __eq__ (self , other ):
83+ if not isinstance (other , ColumnNode ):
84+ return False
85+ return (super ().__eq__ (other ) and
86+ self .name == other .name and
87+ self .alias == other .alias and
88+ self .parent_alias == other .parent_alias )
89+
90+ def __hash__ (self ):
91+ return hash ((super ().__hash__ (), self .name , self .alias , self .parent_alias ))
4692
4793
4894class LiteralNode (Node ):
@@ -51,6 +97,15 @@ def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
5197 super ().__init__ (NodeType .LITERAL , ** kwargs )
5298 self .value = _value
5399
100+ def __eq__ (self , other ):
101+ if not isinstance (other , LiteralNode ):
102+ return False
103+ return (super ().__eq__ (other ) and
104+ self .value == other .value )
105+
106+ def __hash__ (self ):
107+ return hash ((super ().__hash__ (), self .value ))
108+
54109
55110class VarNode (Node ):
56111 """VarSQL variable node"""
@@ -72,37 +127,78 @@ def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwa
72127 children = [_left , _right ] if _right else [_left ]
73128 super ().__init__ (NodeType .OPERATOR , children = children , ** kwargs )
74129 self .name = _name
130+
131+ def __eq__ (self , other ):
132+ if not isinstance (other , OperatorNode ):
133+ return False
134+ return (super ().__eq__ (other ) and
135+ self .name == other .name )
136+
137+ def __hash__ (self ):
138+ return hash ((super ().__hash__ (), self .name ))
75139
76140
77141class FunctionNode (Node ):
78142 """Function call node"""
79- def __init__ (self , _name : str , _args : Optional [List [Node ]] = None , ** kwargs ):
143+ def __init__ (self , _name : str , _args : Optional [List [Node ]] = None , _alias : Optional [ str ] = None , ** kwargs ):
80144 if _args is None :
81145 _args = []
82146 super ().__init__ (NodeType .FUNCTION , children = _args , ** kwargs )
83147 self .name = _name
84-
148+ self .alias = _alias
149+
150+ def __eq__ (self , other ):
151+ if not isinstance (other , FunctionNode ):
152+ return False
153+ return (super ().__eq__ (other ) and
154+ self .name == other .name and
155+ self .alias == other .alias )
156+
157+ def __hash__ (self ):
158+ return hash ((super ().__hash__ (), self .name , self .alias ))
159+
160+
161+ class JoinNode (Node ):
162+ """JOIN clause node"""
163+ def __init__ (self , _left_table : 'TableNode' , _right_table : 'TableNode' , _join_type : JoinType = JoinType .INNER , _on_condition : Optional ['Node' ] = None , ** kwargs ):
164+ children = [_left_table , _right_table ]
165+ if _on_condition :
166+ children .append (_on_condition )
167+ super ().__init__ (NodeType .JOIN , children = children , ** kwargs )
168+ self .left_table = _left_table
169+ self .right_table = _right_table
170+ self .join_type = _join_type
171+ self .on_condition = _on_condition
172+
173+ def __eq__ (self , other ):
174+ if not isinstance (other , JoinNode ):
175+ return False
176+ return (super ().__eq__ (other ) and
177+ self .join_type == other .join_type )
178+
179+ def __hash__ (self ):
180+ return hash ((super ().__hash__ (), self .join_type ))
85181
86182# ============================================================================
87183# Query Structure Nodes
88184# ============================================================================
89185
90186class SelectNode (Node ):
91187 """SELECT clause node"""
92- def __init__ (self , _items : Set ['Node' ], ** kwargs ):
188+ def __init__ (self , _items : List ['Node' ], ** kwargs ):
93189 super ().__init__ (NodeType .SELECT , children = _items , ** kwargs )
94190
95191
96192# TODO - confine the valid NodeTypes as children of FromNode
97193class FromNode (Node ):
98194 """FROM clause node"""
99- def __init__ (self , _sources : Set ['Node' ], ** kwargs ):
195+ def __init__ (self , _sources : List ['Node' ], ** kwargs ):
100196 super ().__init__ (NodeType .FROM , children = _sources , ** kwargs )
101197
102198
103199class WhereNode (Node ):
104200 """WHERE clause node"""
105- def __init__ (self , _predicates : Set ['Node' ], ** kwargs ):
201+ def __init__ (self , _predicates : List ['Node' ], ** kwargs ):
106202 super ().__init__ (NodeType .WHERE , children = _predicates , ** kwargs )
107203
108204
@@ -114,13 +210,28 @@ def __init__(self, _items: List['Node'], **kwargs):
114210
115211class HavingNode (Node ):
116212 """HAVING clause node"""
117- def __init__ (self , _predicates : Set ['Node' ], ** kwargs ):
213+ def __init__ (self , _predicates : List ['Node' ], ** kwargs ):
118214 super ().__init__ (NodeType .HAVING , children = _predicates , ** kwargs )
119215
120216
217+ class OrderByItemNode (Node ):
218+ """Single ORDER BY item"""
219+ def __init__ (self , _column : Node , _sort : SortOrder = SortOrder .ASC , ** kwargs ):
220+ super ().__init__ (NodeType .ORDER_BY_ITEM , children = [_column ], ** kwargs )
221+ self .sort = _sort
222+
223+ def __eq__ (self , other ):
224+ if not isinstance (other , OrderByItemNode ):
225+ return False
226+ return (super ().__eq__ (other ) and
227+ self .sort == other .sort )
228+
229+ def __hash__ (self ):
230+ return hash ((super ().__hash__ (), self .sort ))
231+
121232class OrderByNode (Node ):
122233 """ORDER BY clause node"""
123- def __init__ (self , _items : List ['Node' ], ** kwargs ):
234+ def __init__ (self , _items : List [OrderByItemNode ], ** kwargs ):
124235 super ().__init__ (NodeType .ORDER_BY , children = _items , ** kwargs )
125236
126237
@@ -129,13 +240,31 @@ class LimitNode(Node):
129240 def __init__ (self , _limit : int , ** kwargs ):
130241 super ().__init__ (NodeType .LIMIT , ** kwargs )
131242 self .limit = _limit
243+
244+ def __eq__ (self , other ):
245+ if not isinstance (other , LimitNode ):
246+ return False
247+ return (super ().__eq__ (other ) and
248+ self .limit == other .limit )
249+
250+ def __hash__ (self ):
251+ return hash ((super ().__hash__ (), self .limit ))
132252
133253
134254class OffsetNode (Node ):
135255 """OFFSET clause node"""
136256 def __init__ (self , _offset : int , ** kwargs ):
137257 super ().__init__ (NodeType .OFFSET , ** kwargs )
138258 self .offset = _offset
259+
260+ def __eq__ (self , other ):
261+ if not isinstance (other , OffsetNode ):
262+ return False
263+ return (super ().__eq__ (other ) and
264+ self .offset == other .offset )
265+
266+ def __hash__ (self ):
267+ return hash ((super ().__hash__ (), self .offset ))
139268
140269
141270class QueryNode (Node ):
0 commit comments