-
Notifications
You must be signed in to change notification settings - Fork 3
[Refactor] Add basic formatter #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
colinthebomb1
wants to merge
6
commits into
main
Choose a base branch
from
feature/add-basic-formatter
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
48e12fb
basic formatter
colinthebomb1 4fdaa97
update order by in test
colinthebomb1 50b3a40
create new files for formatter
colinthebomb1 15482c7
add join node support
colinthebomb1 e8fd425
Removed unnecesary imports and debug print lines
colinthebomb1 e0950f7
remove unused import
colinthebomb1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,265 @@ | ||
| import mo_sql_parsing as mosql | ||
| from core.ast.node import QueryNode | ||
| from core.ast.node import ( | ||
| QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, | ||
| LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, | ||
| OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, | ||
colinthebomb1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| JoinNode | ||
| ) | ||
| from core.ast.enums import NodeType, JoinType, SortOrder | ||
| from core.ast.node import Node | ||
|
|
||
| class QueryFormatter: | ||
| def format(self, query: QueryNode) -> str: | ||
| # [1] AST (QueryNode) -> JSON | ||
| json_query = ast_to_json(query) | ||
|
|
||
| # [2] Any (JSON) -> str | ||
| sql = mosql.format(json_query) | ||
|
|
||
| return sql | ||
|
|
||
| def ast_to_json(node: QueryNode) -> dict: | ||
| """Convert QueryNode AST to JSON dictionary for mosql""" | ||
| result = {} | ||
|
|
||
| # process each clause in the query | ||
| for child in node.children: | ||
| if child.type == NodeType.SELECT: | ||
| result['select'] = format_select(child) | ||
| elif child.type == NodeType.FROM: | ||
| result['from'] = format_from(child) | ||
| elif child.type == NodeType.WHERE: | ||
| result['where'] = format_where(child) | ||
| elif child.type == NodeType.GROUP_BY: | ||
| result['groupby'] = format_group_by(child) | ||
| elif child.type == NodeType.HAVING: | ||
| result['having'] = format_having(child) | ||
| elif child.type == NodeType.ORDER_BY: | ||
| result['orderby'] = format_order_by(child) | ||
| elif child.type == NodeType.LIMIT: | ||
| result['limit'] = child.limit | ||
| elif child.type == NodeType.OFFSET: | ||
| result['offset'] = child.offset | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| def format_select(select_node: SelectNode) -> list: | ||
| """Format SELECT clause""" | ||
| items = [] | ||
|
|
||
| for child in select_node.children: | ||
| if child.type == NodeType.COLUMN: | ||
| if child.alias: | ||
| items.append({'name': child.alias, 'value': format_expression(child)}) | ||
| else: | ||
| items.append({'value': format_expression(child)}) | ||
| elif child.type == NodeType.FUNCTION: | ||
| func_expr = format_expression(child) | ||
| if hasattr(child, 'alias') and child.alias: | ||
| items.append({'name': child.alias, 'value': func_expr}) | ||
| else: | ||
| items.append({'value': func_expr}) | ||
| else: | ||
| items.append({'value': format_expression(child)}) | ||
|
|
||
| return items | ||
|
|
||
|
|
||
| def format_from(from_node: FromNode) -> list: | ||
| """Format FROM clause with explicit JOIN support""" | ||
| sources = [] | ||
| children = list(from_node.children) | ||
|
|
||
| if not children: | ||
| return sources | ||
|
|
||
| # Process JoinNode structure | ||
| for child in children: | ||
| if child.type == NodeType.JOIN: | ||
| join_sources = format_join(child) | ||
| # format_join returns a list, extend sources with it | ||
| if isinstance(join_sources, list): | ||
| sources.extend(join_sources) | ||
| else: | ||
| sources.append(join_sources) | ||
| elif child.type == NodeType.TABLE: | ||
| sources.append(format_table(child)) | ||
|
|
||
| return sources | ||
|
|
||
|
|
||
| def format_join(join_node: JoinNode) -> list: | ||
| """Format a JOIN node""" | ||
| children = list(join_node.children) | ||
|
|
||
| if len(children) < 2: | ||
| raise ValueError("JoinNode must have at least 2 children (left and right tables)") | ||
|
|
||
| left_node = children[0] | ||
| right_node = children[1] | ||
| join_condition = children[2] if len(children) > 2 else None | ||
|
|
||
| result = [] | ||
|
|
||
| # Format left side (could be a table or nested join) | ||
| if left_node.type == NodeType.JOIN: | ||
| # Nested join - recursively format | ||
| result.extend(format_join(left_node)) | ||
| elif left_node.type == NodeType.TABLE: | ||
| # Simple table - this becomes the FROM table | ||
| result.append(format_table(left_node)) | ||
|
|
||
| # Format the join itself | ||
| join_dict = {} | ||
|
|
||
| # Map join types to mosql format | ||
| join_type_map = { | ||
| JoinType.INNER: 'join', | ||
| JoinType.LEFT: 'left join', | ||
| JoinType.RIGHT: 'right join', | ||
| JoinType.FULL: 'full join', | ||
| JoinType.CROSS: 'cross join', | ||
| } | ||
|
|
||
| join_key = join_type_map.get(join_node.join_type, 'join') | ||
| join_dict[join_key] = format_table(right_node) | ||
|
|
||
| # Add join condition if it exists | ||
| if join_condition: | ||
| join_dict['on'] = format_expression(join_condition) | ||
|
|
||
| result.append(join_dict) | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| def format_table(table_node: TableNode) -> dict: | ||
| """Format a table reference""" | ||
| result = {'value': table_node.name} | ||
| if table_node.alias: | ||
| result['name'] = table_node.alias | ||
| return result | ||
baiqiushi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def format_where(where_node: WhereNode) -> dict: | ||
| """Format WHERE clause""" | ||
| predicates = list(where_node.children) | ||
| if len(predicates) == 1: | ||
| return format_expression(predicates[0]) | ||
| else: | ||
| return {'and': [format_expression(p) for p in predicates]} | ||
|
|
||
|
|
||
| def format_group_by(group_by_node: GroupByNode) -> list: | ||
| """Format GROUP BY clause""" | ||
| return [{'value': format_expression(child)} | ||
| for child in group_by_node.children] | ||
|
|
||
|
|
||
| def format_having(having_node: HavingNode) -> dict: | ||
| """Format HAVING clause""" | ||
| predicates = list(having_node.children) | ||
| if len(predicates) == 1: | ||
| return format_expression(predicates[0]) | ||
| else: | ||
| return {'and': [format_expression(p) for p in predicates]} | ||
|
|
||
|
|
||
| def format_order_by(order_by_node: OrderByNode) -> list: | ||
| """Format ORDER BY clause items.""" | ||
| items = [] | ||
|
|
||
| # get all items and their sort orders | ||
| sort_orders = [] | ||
| for child in order_by_node.children: | ||
| if child.type == NodeType.ORDER_BY_ITEM: | ||
| column = list(child.children)[0] | ||
|
|
||
| # Check if the column has an alias | ||
| if hasattr(column, 'alias') and column.alias: | ||
| item = {'value': column.alias} | ||
| else: | ||
| item = {'value': format_expression(column)} | ||
|
|
||
| sort_order = child.sort | ||
| sort_orders.append(sort_order) | ||
| else: | ||
| # Direct column reference (no OrderByItemNode wrapper) | ||
| if hasattr(child, 'alias') and child.alias: | ||
| item = {'value': child.alias} | ||
| else: | ||
| item = {'value': format_expression(child)} | ||
|
|
||
| sort_order = SortOrder.ASC | ||
| sort_orders.append(sort_order) | ||
|
|
||
| items.append((item, sort_order)) | ||
|
|
||
| # check if all sort orders are the same | ||
| all_same = len(set(sort_orders)) == 1 | ||
| common_sort = sort_orders[0] if all_same else None | ||
|
|
||
| # reformat into single sort operator if all items have same sort operator | ||
| # ex. ORDER BY dept_name DESC, emp_count DESC -> ORDER BY dept_name, emp_count DESC | ||
| result = [] | ||
| for i, (item, sort_order) in enumerate(items): | ||
| if all_same and i == len(items) - 1: | ||
| if common_sort != SortOrder.ASC: | ||
| item['sort'] = common_sort.value.lower() | ||
| elif not all_same: | ||
| if sort_order != SortOrder.ASC: | ||
| item['sort'] = sort_order.value.lower() | ||
|
|
||
| result.append(item) | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| def format_expression(node: Node): | ||
| """Format an expression node""" | ||
baiqiushi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if node.type == NodeType.COLUMN: | ||
| if node.parent_alias: | ||
| return f"{node.parent_alias}.{node.name}" | ||
| return node.name | ||
|
|
||
| elif node.type == NodeType.LITERAL: | ||
| return node.value | ||
|
|
||
| elif node.type == NodeType.FUNCTION: | ||
| # format: {'function_name': args} | ||
| func_name = node.name.lower() | ||
| args = [format_expression(arg) for arg in node.children] | ||
colinthebomb1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return {func_name: args[0] if len(args) == 1 else args} | ||
|
|
||
| elif node.type == NodeType.OPERATOR: | ||
| # format: {'operator': [left, right]} | ||
| op_map = { | ||
| '>': 'gt', | ||
| '<': 'lt', | ||
| '>=': 'gte', | ||
| '<=': 'lte', | ||
| '=': 'eq', | ||
| '!=': 'ne', | ||
| 'AND': 'and', | ||
| 'OR': 'or', | ||
| } | ||
|
|
||
| op_name = op_map.get(node.name.upper(), node.name.lower()) | ||
| children = list(node.children) | ||
|
|
||
| left = format_expression(children[0]) | ||
|
|
||
| if len(children) == 2: | ||
| right = format_expression(children[1]) | ||
| return {op_name: [left, right]} | ||
| else: | ||
| # unary operator | ||
| return {op_name: left} | ||
|
|
||
| elif node.type == NodeType.TABLE: | ||
| return format_table(node) | ||
|
|
||
| else: | ||
| raise ValueError(f"Unsupported node type in expression: {node.type}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| import mo_sql_parsing as mosql | ||
| from core.query_formatter import QueryFormatter | ||
| from core.ast.node import ( | ||
| OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, | ||
| LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, | ||
| OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode | ||
colinthebomb1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| from core.ast.enums import NodeType, JoinType, SortOrder | ||
colinthebomb1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from data.queries import get_query | ||
colinthebomb1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from re import sub | ||
|
|
||
| formatter = QueryFormatter() | ||
|
|
||
| def normalize_sql(s): | ||
| """Remove extra whitespace and normalize SQL string to be used in comparisons""" | ||
| s = s.strip() | ||
| s = sub(r'\s+', ' ', s) | ||
|
|
||
| return s | ||
|
|
||
| def test_basic_format(): | ||
| # Construct expected AST | ||
| # Tables | ||
| emp_table = TableNode("employees", "e") | ||
| dept_table = TableNode("departments", "d") | ||
| # Columns | ||
| emp_name = ColumnNode("name", _parent_alias="e") | ||
| emp_salary = ColumnNode("salary", _parent_alias="e") | ||
| emp_age = ColumnNode("age", _parent_alias="e") | ||
| emp_dept_id = ColumnNode("department_id", _parent_alias="e") | ||
|
|
||
| dept_name = ColumnNode("name", _alias="dept_name", _parent_alias="d") | ||
| dept_id = ColumnNode("id", _parent_alias="d") | ||
|
|
||
| count_star = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")]) | ||
|
|
||
| # SELECT clause | ||
| select_clause = SelectNode([emp_name, dept_name, count_star]) | ||
| # FROM clause with JOIN | ||
| join_condition = OperatorNode(emp_dept_id, "=", dept_id) | ||
| join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) | ||
| from_clause = FromNode([join_node]) | ||
| # WHERE clause | ||
| salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) | ||
| age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) | ||
| where_condition = OperatorNode(salary_condition, "AND", age_condition) | ||
| where_clause = WhereNode([where_condition]) | ||
| # GROUP BY clause | ||
| group_by_clause = GroupByNode([dept_id, dept_name]) | ||
| # HAVING clause | ||
| having_condition = OperatorNode(count_star, ">", LiteralNode(2)) | ||
| having_clause = HavingNode([having_condition]) | ||
| # ORDER BY clause | ||
| order_by_item1 = OrderByItemNode(dept_name, SortOrder.ASC) | ||
| order_by_item2 = OrderByItemNode(count_star, SortOrder.DESC) | ||
| order_by_clause = OrderByNode([order_by_item1, order_by_item2]) | ||
| # LIMIT and OFFSET | ||
| limit_clause = LimitNode(10) | ||
| offset_clause = OffsetNode(5) | ||
| # Complete query | ||
| ast = QueryNode( | ||
| _select=select_clause, | ||
| _from=from_clause, | ||
| _where=where_clause, | ||
| _group_by=group_by_clause, | ||
| _having=having_clause, | ||
| _order_by=order_by_clause, | ||
| _limit=limit_clause, | ||
| _offset=offset_clause | ||
| ) | ||
|
|
||
| # Construct expected query text | ||
| expected_sql = """ | ||
| SELECT e.name, d.name AS dept_name, COUNT(*) AS emp_count | ||
| FROM employees AS e JOIN departments AS d ON e.department_id = d.id | ||
| WHERE e.salary > 40000 AND e.age < 60 | ||
| GROUP BY d.id, d.name | ||
| HAVING COUNT(*) > 2 | ||
| ORDER BY dept_name, emp_count DESC | ||
| LIMIT 10 OFFSET 5 | ||
| """ | ||
| expected_sql = expected_sql.strip() | ||
| print(mosql.parse(expected_sql)) | ||
| print(ast) | ||
baiqiushi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| sql = formatter.format(ast) | ||
| sql = sql.strip() | ||
|
|
||
| assert normalize_sql(sql) == normalize_sql(expected_sql) | ||
colinthebomb1 marked this conversation as resolved.
Show resolved
Hide resolved
colinthebomb1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.