Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 342 additions & 11 deletions core/query_parser.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,354 @@
from core.ast.node import QueryNode
from core.ast.node import (
Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'SubqueryNode' is not used.
Import of 'VarNode' is not used.
Import of 'VarSetNode' is not used.

Suggested change
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed that this version of the implementation does not handle subqueries. We will address them in the next iteration.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. In that case I will leave a TODO comment above as a reminder.

)
from core.ast.enums import JoinType, SortOrder
import mo_sql_parsing as mosql

class QueryParser:
@staticmethod
def normalize_to_list(value):
"""Normalize mo_sql_parsing output to a list format.

mo_sql_parsing returns:
- list when multiple items
- dict when single item with structure
- str when single simple value

This normalizes all cases to a list.
"""
if value is None:
return []
elif isinstance(value, list):
return value
elif isinstance(value, (dict, str)):
return [value]
else:
return [value]

def parse(self, query: str) -> QueryNode:
# Implement parsing logic using self.rules
pass

# [1] Call mo_sql_parser
# str -> Any (JSON)
mosql_ast = mosql.parse(query)

# [2] Our new code
# Any (JSON) -> AST (QueryNode)
self.aliases = {}
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self.aliases instance variable is initialized within the parse() method rather than in __init__. This means the aliases dictionary persists between parse calls and could lead to stale alias references affecting subsequent parses. Consider either: 1) initializing self.aliases = {} in a proper __init__ method, or 2) passing aliases as a parameter through the helper methods instead of using instance state.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, which I also noticed. If we use an instance member to store aliases, the instance is not safe for multi-threading. [1] One way to solve this is to pass a helper parameter through all of the subsequent function calls as a function argument, similar to what we use as the memo in the rewriter engine. [2] Alternatively, we can declare QueryParser as non-thread-safe and create a new parser instance everywhere we need to call the parse function, i.e., parser = QueryParser(). If we use this approach, it makes more sense to initialize self.aliases = {} in the __init__ function instead of in the parse function. @HazelYuAhiru, please make a decision in coordination with @colinthebomb1's PR: if he uses the same approach (some internal state shared by different functions as instance-level state), we can use [2]. If he already used [1], you may switch to [1]. If you decide to use [2], please find a proper annotation tag in Python to mark it as non-thread-safe.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've updated the parser to use a thread-safe approach.


select_clause = None
from_clause = None
where_clause = None
group_by_clause = None
having_clause = None
order_by_clause = None
limit_clause = None
offset_clause = None

if 'select' in mosql_ast:
select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select']))
if 'from' in mosql_ast:
from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from']))
if 'where' in mosql_ast:
where_clause = self.parse_where(mosql_ast['where'])
if 'groupby' in mosql_ast:
group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby']))
if 'having' in mosql_ast:
having_clause = self.parse_having(mosql_ast['having'])
if 'orderby' in mosql_ast:
order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby']))
if 'limit' in mosql_ast:
limit_clause = LimitNode(mosql_ast['limit'])
if 'offset' in mosql_ast:
offset_clause = OffsetNode(mosql_ast['offset'])

return QueryNode(
_select=select_clause,
_from=from_clause,
_where=where_clause,
_group_by=group_by_clause,
_having=having_clause,
_order_by=order_by_clause,
_limit=limit_clause,
_offset=offset_clause
)

def parse_select(self, select_list: list) -> SelectNode:
    """Build a SelectNode from the normalized SELECT items.

    Args:
        select_list: SELECT entries; each is either a dict carrying a
            'value' (and optionally a 'name' alias) or a bare expression.

    Returns:
        A SelectNode wrapping the set of parsed expressions.

    Side effects:
        Stores aliased expressions in ``self.aliases`` keyed by alias.
    """
    parsed_items = set()
    for entry in select_list:
        if not (isinstance(entry, dict) and 'value' in entry):
            # Bare expression (string, int, etc.) with no wrapper dict.
            parsed_items.add(self.parse_expression(entry))
            continue

        node = self.parse_expression(entry['value'])
        if 'name' in entry:
            alias_name = entry['name']
            # Only nodes that expose an ``alias`` attribute get it stamped.
            if hasattr(node, 'alias'):
                node.alias = alias_name
            # NOTE(review): assumes the alias is registered for every named
            # item, not only for nodes with an ``alias`` attribute - confirm.
            self.aliases[alias_name] = node
        parsed_items.add(node)

    return SelectNode(parsed_items)

def parse_from(self, from_list: list) -> FromNode:
    """Build a FromNode from the normalized FROM items.

    The first table seen becomes the running left-hand source. Each JOIN
    entry wraps the running source and its right table in a JoinNode,
    which then becomes the new running source. Tables that appear without
    a JOIN after the first one are collected as separate sources.

    Args:
        from_list: FROM entries; each is a table-name string, a dict with
            a 'value' (and optional 'name' alias), or a dict whose key
            contains 'join' (with an optional 'on' condition).

    Returns:
        A FromNode wrapping the set of sources.

    Raises:
        ValueError: If a JOIN entry appears before any table.

    Side effects:
        Stores aliased tables in ``self.aliases`` keyed by alias.
    """
    extra_sources = set()
    current = None  # a table, or the JoinNode produced by the previous join

    for entry in from_list:
        if isinstance(entry, str):
            # Plain table name.
            table = TableNode(entry)
            if current is None:
                current = table
            else:
                extra_sources.add(table)
            continue

        if not isinstance(entry, dict):
            continue

        # JOIN detection comes before the plain 'value' check.
        join_key = next((k for k in entry.keys() if 'join' in k.lower()), None)

        if join_key:
            if current is None:
                raise ValueError("JOIN found without a left table")

            target = entry[join_key]
            if isinstance(target, dict):
                table_name, alias = target['value'], target.get('name')
            else:
                table_name, alias = target, None

            right = TableNode(table_name, alias)
            if alias:
                self.aliases[alias] = right

            condition = self.parse_expression(entry['on']) if 'on' in entry else None

            # The join result is the left side of any following JOIN.
            current = JoinNode(current, right, self.parse_join_type(join_key), condition)
        elif 'value' in entry:
            alias = entry.get('name')
            table = TableNode(entry['value'], alias)
            if alias:
                self.aliases[alias] = table

            if current is None:
                current = table
            else:
                # Additional table listed without an explicit JOIN.
                extra_sources.add(table)

    # The running source is either a single table or a chain of joins.
    if current is not None:
        extra_sources.add(current)

    return FromNode(extra_sources)

def parse_where(self, where_dict: dict) -> WhereNode:
    """Build a WhereNode holding the single parsed WHERE expression.

    Args:
        where_dict: The WHERE value from the parsed query.

    Returns:
        A WhereNode wrapping a one-element set of predicates.
    """
    condition = self.parse_expression(where_dict)
    return WhereNode({condition})

def parse_group_by(self, group_by_list: list) -> GroupByNode:
    """Build a GroupByNode from the normalized GROUP BY items.

    Each item is parsed with ``parse_expression`` and then passed through
    ``resolve_aliases``.

    Args:
        group_by_list: GROUP BY entries; each is either a dict carrying a
            'value' or a bare expression (string, int, etc.).

    Returns:
        A GroupByNode wrapping the list of parsed expressions, in input
        order.
    """
    items = []
    for item in group_by_list:
        # Unwrap {'value': ...} entries; bare expressions are parsed as-is.
        raw = item['value'] if isinstance(item, dict) and 'value' in item else item
        items.append(self.resolve_aliases(self.parse_expression(raw)))
    return GroupByNode(items)

def format(self, query: QueryNode) -> str:
    """Placeholder for converting an AST back to a SQL string (not implemented)."""
    # Implement formatting logic to convert AST back to SQL string
    pass

def parse_having(self, having_dict: dict) -> HavingNode:
    """Build a HavingNode holding the single parsed HAVING expression.

    The parsed expression is passed through ``resolve_aliases`` before it
    is stored.

    Args:
        having_dict: The HAVING value from the parsed query.

    Returns:
        A HavingNode wrapping a one-element set of predicates.
    """
    condition = self.resolve_aliases(self.parse_expression(having_dict))
    return HavingNode({condition})

def parse_order_by(self, order_by_list: list) -> OrderByNode:
    """Build an OrderByNode from the normalized ORDER BY items.

    Args:
        order_by_list: ORDER BY entries; each is either a dict carrying a
            'value' (and optionally a 'sort' direction) or a bare
            expression.

    Returns:
        An OrderByNode wrapping OrderByItemNode entries in input order.
        The direction is ``SortOrder.DESC`` only when 'sort' upper-cases
        to 'DESC'; otherwise it is ``SortOrder.ASC``.
    """
    ordered = []
    for entry in order_by_list:
        if not (isinstance(entry, dict) and 'value' in entry):
            # Bare expression (string, int, etc.): always ascending.
            ordered.append(OrderByItemNode(self.parse_expression(entry), SortOrder.ASC))
            continue

        target = entry['value']
        # A plain string that matches a recorded alias reuses the aliased node.
        is_alias = isinstance(target, str) and target in self.aliases
        node = self.aliases[target] if is_alias else self.parse_expression(target)

        direction = SortOrder.ASC
        if 'sort' in entry and entry['sort'].upper() == 'DESC':
            direction = SortOrder.DESC

        ordered.append(OrderByItemNode(node, direction))

    return OrderByNode(ordered)

def resolve_aliases(self, expr: Node) -> Node:
if isinstance(expr, OperatorNode):
# Recursively resolve aliases in operator operands
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the OperatorNode is created with unary operators (e.g., 'NOT'), only one operand is passed (line 309). The resolve_aliases method assumes binary operators and always accesses children[0] and children[1] (lines 227-228). This will cause an IndexError when resolving aliases for unary operators. Add a check for the number of children before accessing indices.

Suggested change
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
if len(expr.children) == 1:
child = self.resolve_aliases(expr.children[0])
return OperatorNode(child, expr.name)
elif len(expr.children) == 2:
left = self.resolve_aliases(expr.children[0])
right = self.resolve_aliases(expr.children[1])
return OperatorNode(left, expr.name, right)
else:
# Unexpected number of children; return as is
return expr

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. But I doubt if we can resolve it in this PR since we don't have such test cases. Let's push this fix to later PRs where we introduce more test cases.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, I added a TODO in this section as a reminder.

elif isinstance(expr, FunctionNode):
# Check if this function matches an aliased function from SELECT
if expr.alias is None:
for alias, aliased_expr in self.aliases.items():
if isinstance(aliased_expr, FunctionNode):
if (expr.name == aliased_expr.name and
len(expr.children) == len(aliased_expr.children) and
all(expr.children[i] == aliased_expr.children[i]
for i in range(len(expr.children)))):
# This function matches an aliased one, use the alias
expr.alias = alias
break
return expr
elif isinstance(expr, ColumnNode):
# Check if this column matches an aliased column from SELECT
if expr.alias is None:
for alias, aliased_expr in self.aliases.items():
if isinstance(aliased_expr, ColumnNode):
if (expr.name == aliased_expr.name and
expr.parent_alias == aliased_expr.parent_alias):
# This column matches an aliased one, use the alias
expr.alias = alias
break
return expr
else:
return expr
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The alias resolution in the HAVING clause iterates through all aliases to find matching FunctionNode or ColumnNode instances (lines 233-241, 246-252). For large SELECT clauses with many aliases, this could be inefficient. Consider optimizing by creating a separate lookup structure or only checking relevant aliases.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is OK for now.


def parse_expression(self, expr) -> Node:
    """Convert one mo_sql_parsing expression into an AST node.

    Dispatch by the Python type of ``expr``:
      * str   -> ColumnNode; text before the first '.' becomes the
                 parent alias.
      * int / float / bool -> LiteralNode.
      * list  -> LiteralNode holding a tuple of parsed elements.
      * dict  -> 'all_columns' and 'literal' are handled specially;
                 otherwise the first key that is not a metadata key is
                 treated as the operator or function name.
      * anything else -> LiteralNode wrapping the value unchanged.

    Args:
        expr: Expression fragment from the parsed query.

    Returns:
        The AST node for the expression.
    """
    if isinstance(expr, str):
        # Column reference, optionally qualified as "<parent>.<name>".
        parent, dot, name = expr.partition('.')
        if dot:
            return ColumnNode(name, _parent_alias=parent)
        return ColumnNode(expr)

    if isinstance(expr, (int, float, bool)):
        return LiteralNode(expr)

    if isinstance(expr, list):
        # List literals (for IN clauses) become tuples so they are hashable.
        return LiteralNode(tuple(self.parse_expression(element) for element in expr))

    if not isinstance(expr, dict):
        return LiteralNode(expr)

    if 'all_columns' in expr:
        return ColumnNode('*')
    if 'literal' in expr:
        return LiteralNode(expr['literal'])

    # Metadata keys never name an operator or function.
    metadata_keys = {'value', 'name', 'on', 'sort'}
    not_found = object()
    key = next((k for k in expr.keys() if k not in metadata_keys), not_found)

    if key is not_found:
        # Nothing usable: keep the dict as a deterministic JSON literal.
        import json
        return LiteralNode(json.dumps(expr, sort_keys=True))

    payload = expr[key]
    op_name = self.normalize_operator_name(key)

    if isinstance(payload, list):
        # Operator with a list of operands.
        if len(payload) == 0:
            return LiteralNode(None)
        if len(payload) == 1:
            return self.parse_expression(payload[0])

        operands = [self.parse_expression(operand) for operand in payload]
        # Chain the operands left to right with the same operator.
        chained, *remaining = operands
        for operand in remaining:
            chained = OperatorNode(chained, op_name, operand)
        return chained

    if key == 'not':
        # Unary operator.
        return OperatorNode(self.parse_expression(payload), 'NOT')

    if payload == '*':
        # Function applied to '*', e.g. COUNT(*).
        return FunctionNode(op_name, [ColumnNode('*')])

    # Regular function call with one argument.
    return FunctionNode(op_name, [self.parse_expression(payload)])

@staticmethod
def normalize_operator_name(key: str) -> str:
"""Convert mo_sql_parsing operator keys to SQL operator names."""
mapping = {
'eq': '=', 'neq': '!=', 'ne': '!=',
'gt': '>', 'gte': '>=',
'lt': '<', 'lte': '<=',
'and': 'AND', 'or': 'OR',
}
return mapping.get(key.lower(), key.upper())

@staticmethod
def parse_join_type(join_key: str) -> JoinType:
    """Derive the JoinType from a mo_sql_parsing join key.

    The key is lower-cased and searched for a join keyword; the first
    keyword found, in the order inner, left, right, full, cross, decides
    the result.

    Args:
        join_key: The join key from a FROM entry.

    Returns:
        The matching JoinType, or ``JoinType.INNER`` when no keyword is
        present.
    """
    normalized = join_key.lower().replace(' ', '_')

    # Order matters: earlier keywords win when several are present.
    keyword_table = (
        ('inner', JoinType.INNER),
        ('left', JoinType.LEFT),
        ('right', JoinType.RIGHT),
        ('full', JoinType.FULL),
        ('cross', JoinType.CROSS),
    )
    for keyword, join_type in keyword_table:
        if keyword in normalized:
            return join_type

    return JoinType.INNER  # By default
Loading