Skip to content

Commit 31ff011

Browse files
[Refactor] Improve query structure nodes (#88)
* imrpove ast * fix documentation * add node changes from parser * update node_type for join * add sort order enum * add __eq__ and __hash__ for OrderByItemNode and OrderByNode * add enum to join type * combine enum files * simplify OrderByNode * make tests use new enums file --------- Co-authored-by: Yihong Yu <[email protected]>
1 parent 3d7b980 commit 31ff011

File tree

5 files changed

+202
-42
lines changed

5 files changed

+202
-42
lines changed

core/ast/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This module provides the node types and classes for representing SQL query structures.
55
"""
66

7-
from .node_type import NodeType
7+
from .enums import NodeType
88
from .node import (
99
Node,
1010
TableNode,
@@ -18,6 +18,7 @@
1818
SelectNode,
1919
FromNode,
2020
WhereNode,
21+
JoinNode,
2122
GroupByNode,
2223
HavingNode,
2324
OrderByNode,
@@ -40,9 +41,11 @@
4041
'SelectNode',
4142
'FromNode',
4243
'WhereNode',
44+
'JoinNode',
4345
'GroupByNode',
4446
'HavingNode',
4547
'OrderByNode',
48+
'OrderByItemNode',
4649
'LimitNode',
4750
'OffsetNode',
4851
'QueryNode'

core/ast/enums.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from enum import Enum
2+
3+
# ============================================================================
4+
# Node Type Enumeration
5+
# ============================================================================
6+
7+
class NodeType(Enum):
8+
"""Node type enumeration"""
9+
10+
# Operands
11+
TABLE = "table"
12+
SUBQUERY = "subquery"
13+
COLUMN = "column"
14+
LITERAL = "literal"
15+
# VarSQL specific
16+
VAR = "var"
17+
VARSET = "varset"
18+
19+
# Operators
20+
OPERATOR = "operator"
21+
FUNCTION = "function"
22+
23+
# Query structure
24+
SELECT = "select"
25+
FROM = "from"
26+
WHERE = "where"
27+
JOIN = "join"
28+
GROUP_BY = "group_by"
29+
HAVING = "having"
30+
ORDER_BY = "order_by"
31+
ORDER_BY_ITEM = "order_by_item"
32+
LIMIT = "limit"
33+
OFFSET = "offset"
34+
QUERY = "query"
35+
36+
# ============================================================================
37+
# Join Type Enumeration
38+
# ============================================================================
39+
40+
class JoinType(Enum):
41+
"""Join type enumeration"""
42+
INNER = "inner"
43+
OUTER = "outer"
44+
LEFT = "left"
45+
RIGHT = "right"
46+
FULL = "full"
47+
CROSS = "cross"
48+
NATURAL = "natural"
49+
SEMI = "semi"
50+
ANTI = "anti"
51+
52+
53+
# ============================================================================
54+
# Sort Order Enumeration
55+
# ============================================================================
56+
57+
class SortOrder(Enum):
58+
"""Sort order enum"""
59+
ASC = "ASC"
60+
DESC = "DESC"

core/ast/node.py

Lines changed: 137 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Set, Optional
33
from abc import ABC
44

5-
from .node_type import NodeType
5+
from .enums import NodeType, JoinType, SortOrder
66

77
# ============================================================================
88
# Base Node Structure
@@ -13,6 +13,31 @@ class Node(ABC):
1313
def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None):
1414
self.type = type
1515
self.children = children if children is not None else set()
16+
17+
def __eq__(self, other):
18+
if not isinstance(other, Node):
19+
return False
20+
if self.type != other.type:
21+
return False
22+
if len(self.children) != len(other.children):
23+
return False
24+
# Compare children
25+
if isinstance(self.children, set) and isinstance(other.children, set):
26+
return self.children == other.children
27+
elif isinstance(self.children, list) and isinstance(other.children, list):
28+
return self.children == other.children
29+
else:
30+
return False
31+
32+
def __hash__(self):
33+
# Make nodes hashable by using their type and a hash of their children
34+
if isinstance(self.children, set):
35+
# For sets, create a deterministic hash by sorting children by their string representation
36+
children_hash = hash(tuple(sorted(self.children, key=lambda x: str(x))))
37+
else:
38+
# For lists, just hash the tuple directly
39+
children_hash = hash(tuple(self.children))
40+
return hash((self.type, children_hash))
1641

1742

1843
# ============================================================================
@@ -25,6 +50,16 @@ def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
2550
super().__init__(NodeType.TABLE, **kwargs)
2651
self.name = _name
2752
self.alias = _alias
53+
54+
def __eq__(self, other):
55+
if not isinstance(other, TableNode):
56+
return False
57+
return (super().__eq__(other) and
58+
self.name == other.name and
59+
self.alias == other.alias)
60+
61+
def __hash__(self):
62+
return hash((super().__hash__(), self.name, self.alias))
2863

2964

3065
# TODO - including query structure arguments (similar to QueryNode) in constructor.
@@ -43,6 +78,17 @@ def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Opti
4378
self.alias = _alias
4479
self.parent_alias = _parent_alias
4580
self.parent = _parent
81+
82+
def __eq__(self, other):
83+
if not isinstance(other, ColumnNode):
84+
return False
85+
return (super().__eq__(other) and
86+
self.name == other.name and
87+
self.alias == other.alias and
88+
self.parent_alias == other.parent_alias)
89+
90+
def __hash__(self):
91+
return hash((super().__hash__(), self.name, self.alias, self.parent_alias))
4692

4793

4894
class LiteralNode(Node):
@@ -51,6 +97,15 @@ def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
5197
super().__init__(NodeType.LITERAL, **kwargs)
5298
self.value = _value
5399

100+
def __eq__(self, other):
101+
if not isinstance(other, LiteralNode):
102+
return False
103+
return (super().__eq__(other) and
104+
self.value == other.value)
105+
106+
def __hash__(self):
107+
return hash((super().__hash__(), self.value))
108+
54109

55110
class VarNode(Node):
56111
"""VarSQL variable node"""
@@ -72,37 +127,78 @@ def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwa
72127
children = [_left, _right] if _right else [_left]
73128
super().__init__(NodeType.OPERATOR, children=children, **kwargs)
74129
self.name = _name
130+
131+
def __eq__(self, other):
132+
if not isinstance(other, OperatorNode):
133+
return False
134+
return (super().__eq__(other) and
135+
self.name == other.name)
136+
137+
def __hash__(self):
138+
return hash((super().__hash__(), self.name))
75139

76140

77141
class FunctionNode(Node):
78142
"""Function call node"""
79-
def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs):
143+
def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs):
80144
if _args is None:
81145
_args = []
82146
super().__init__(NodeType.FUNCTION, children=_args, **kwargs)
83147
self.name = _name
84-
148+
self.alias = _alias
149+
150+
def __eq__(self, other):
151+
if not isinstance(other, FunctionNode):
152+
return False
153+
return (super().__eq__(other) and
154+
self.name == other.name and
155+
self.alias == other.alias)
156+
157+
def __hash__(self):
158+
return hash((super().__hash__(), self.name, self.alias))
159+
160+
161+
class JoinNode(Node):
162+
"""JOIN clause node"""
163+
def __init__(self, _left_table: 'TableNode', _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs):
164+
children = [_left_table, _right_table]
165+
if _on_condition:
166+
children.append(_on_condition)
167+
super().__init__(NodeType.JOIN, children=children, **kwargs)
168+
self.left_table = _left_table
169+
self.right_table = _right_table
170+
self.join_type = _join_type
171+
self.on_condition = _on_condition
172+
173+
def __eq__(self, other):
174+
if not isinstance(other, JoinNode):
175+
return False
176+
return (super().__eq__(other) and
177+
self.join_type == other.join_type)
178+
179+
def __hash__(self):
180+
return hash((super().__hash__(), self.join_type))
85181

86182
# ============================================================================
87183
# Query Structure Nodes
88184
# ============================================================================
89185

90186
class SelectNode(Node):
91187
"""SELECT clause node"""
92-
def __init__(self, _items: Set['Node'], **kwargs):
188+
def __init__(self, _items: List['Node'], **kwargs):
93189
super().__init__(NodeType.SELECT, children=_items, **kwargs)
94190

95191

96192
# TODO - confine the valid NodeTypes as children of FromNode
97193
class FromNode(Node):
98194
"""FROM clause node"""
99-
def __init__(self, _sources: Set['Node'], **kwargs):
195+
def __init__(self, _sources: List['Node'], **kwargs):
100196
super().__init__(NodeType.FROM, children=_sources, **kwargs)
101197

102198

103199
class WhereNode(Node):
104200
"""WHERE clause node"""
105-
def __init__(self, _predicates: Set['Node'], **kwargs):
201+
def __init__(self, _predicates: List['Node'], **kwargs):
106202
super().__init__(NodeType.WHERE, children=_predicates, **kwargs)
107203

108204

@@ -114,13 +210,28 @@ def __init__(self, _items: List['Node'], **kwargs):
114210

115211
class HavingNode(Node):
116212
"""HAVING clause node"""
117-
def __init__(self, _predicates: Set['Node'], **kwargs):
213+
def __init__(self, _predicates: List['Node'], **kwargs):
118214
super().__init__(NodeType.HAVING, children=_predicates, **kwargs)
119215

120216

217+
class OrderByItemNode(Node):
218+
"""Single ORDER BY item"""
219+
def __init__(self, _column: Node, _sort: SortOrder = SortOrder.ASC, **kwargs):
220+
super().__init__(NodeType.ORDER_BY_ITEM, children=[_column], **kwargs)
221+
self.sort = _sort
222+
223+
def __eq__(self, other):
224+
if not isinstance(other, OrderByItemNode):
225+
return False
226+
return (super().__eq__(other) and
227+
self.sort == other.sort)
228+
229+
def __hash__(self):
230+
return hash((super().__hash__(), self.sort))
231+
121232
class OrderByNode(Node):
122233
"""ORDER BY clause node"""
123-
def __init__(self, _items: List['Node'], **kwargs):
234+
def __init__(self, _items: List[OrderByItemNode], **kwargs):
124235
super().__init__(NodeType.ORDER_BY, children=_items, **kwargs)
125236

126237

@@ -129,13 +240,31 @@ class LimitNode(Node):
129240
def __init__(self, _limit: int, **kwargs):
130241
super().__init__(NodeType.LIMIT, **kwargs)
131242
self.limit = _limit
243+
244+
def __eq__(self, other):
245+
if not isinstance(other, LimitNode):
246+
return False
247+
return (super().__eq__(other) and
248+
self.limit == other.limit)
249+
250+
def __hash__(self):
251+
return hash((super().__hash__(), self.limit))
132252

133253

134254
class OffsetNode(Node):
135255
"""OFFSET clause node"""
136256
def __init__(self, _offset: int, **kwargs):
137257
super().__init__(NodeType.OFFSET, **kwargs)
138258
self.offset = _offset
259+
260+
def __eq__(self, other):
261+
if not isinstance(other, OffsetNode):
262+
return False
263+
return (super().__eq__(other) and
264+
self.offset == other.offset)
265+
266+
def __hash__(self):
267+
return hash((super().__hash__(), self.offset))
139268

140269

141270
class QueryNode(Node):

core/ast/node_type.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

tests/test_query_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
66
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode
77
)
8-
from core.ast.node_type import NodeType
8+
from core.ast.enums import NodeType, JoinType, SortOrder
99
from data.queries import get_query
1010

1111
parser = QueryParser()

0 commit comments

Comments
 (0)