Skip to content

Commit 1f55da3

Browse files
authored
refactor: replace a few functions in sql.py with AST. (#784)
* refactor: update Select, update sql.py * update test * update test * update test
1 parent 0433a1c commit 1f55da3

File tree

8 files changed

+151
-89
lines changed

8 files changed

+151
-89
lines changed

bigframes/core/blocks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import bigframes.constants
4141
import bigframes.constants as constants
4242
import bigframes.core as core
43+
import bigframes.core.compile.googlesql as googlesql
4344
import bigframes.core.expression as ex
4445
import bigframes.core.expression as scalars
4546
import bigframes.core.guid as guid
@@ -2417,7 +2418,9 @@ def _get_rows_as_json_values(self) -> Block:
24172418
select_columns = (
24182419
[ordering_column_name] + list(self.index_columns) + [row_json_column_name]
24192420
)
2420-
select_columns_csv = sql.csv([sql.identifier(col) for col in select_columns])
2421+
select_columns_csv = sql.csv(
2422+
[googlesql.identifier(col) for col in select_columns]
2423+
)
24212424
json_sql = f"""\
24222425
With T0 AS (
24232426
{textwrap.indent(expr_sql, " ")}
@@ -2430,7 +2433,7 @@ def _get_rows_as_json_values(self) -> Block:
24302433
"values", [{column_references_csv}],
24312434
"indexlength", {index_columns_count},
24322435
"dtype", {pandas_row_dtype}
2433-
) AS {sql.identifier(row_json_column_name)} FROM T0
2436+
) AS {googlesql.identifier(row_json_column_name)} FROM T0
24342437
)
24352438
SELECT {select_columns_csv} FROM T1
24362439
"""

bigframes/core/compile/compiled.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import pandas
2929

3030
import bigframes.core.compile.aggregate_compiler as agg_compiler
31+
import bigframes.core.compile.googlesql
3132
import bigframes.core.compile.ibis_types
3233
import bigframes.core.compile.scalar_op_compiler as op_compilers
3334
import bigframes.core.expression as ex
@@ -905,7 +906,12 @@ def to_sql(
905906
output_columns = [
906907
col_id_overrides.get(col, col) for col in baked_ir.column_ids
907908
]
908-
sql = bigframes.core.sql.select_from_subquery(output_columns, sql)
909+
sql = (
910+
bigframes.core.compile.googlesql.Select()
911+
.from_(sql)
912+
.select(output_columns)
913+
.sql()
914+
)
909915

910916
# Single row frames may not have any ordering columns
911917
if len(baked_ir._ordering.all_ordering_columns) > 0:

bigframes/core/compile/googlesql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AliasExpression,
2323
ColumnExpression,
2424
CTEExpression,
25+
identifier,
2526
StarExpression,
2627
TableExpression,
2728
)
@@ -38,6 +39,7 @@
3839

3940
__all__ = [
4041
"_escape_chars",
42+
"identifier",
4143
"AliasExpression",
4244
"AsAlias",
4345
"ColumnExpression",

bigframes/core/compile/googlesql/expression.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class ColumnExpression(Expression):
4545

4646
def sql(self) -> str:
4747
if self.parent is not None:
48-
return f"{self.parent.sql()}.`{self.name}`"
49-
return f"`{self.name}`"
48+
return f"{self.parent.sql()}.{identifier(self.name)}"
49+
return identifier(self.name)
5050

5151

5252
@dataclasses.dataclass
@@ -72,10 +72,10 @@ def __post_init__(self):
7272
def sql(self) -> str:
7373
text = []
7474
if self.project_id is not None:
75-
text.append(f"`{_escape_chars(self.project_id)}`")
75+
text.append(identifier(self.project_id))
7676
if self.dataset_id is not None:
77-
text.append(f"`{_escape_chars(self.dataset_id)}`")
78-
text.append(f"`{_escape_chars(self.table_id)}`")
77+
text.append(identifier(self.dataset_id))
78+
text.append(identifier(self.table_id))
7979
return ".".join(text)
8080

8181

@@ -84,15 +84,22 @@ class AliasExpression(Expression):
8484
alias: str
8585

8686
def sql(self) -> str:
87-
return f"`{_escape_chars(self.alias)}`"
87+
return identifier(self.alias)
8888

8989

9090
@dataclasses.dataclass
9191
class CTEExpression(Expression):
9292
name: str
9393

9494
def sql(self) -> str:
95-
return f"`{_escape_chars(self.name)}`"
95+
return identifier(self.name)
96+
97+
98+
def identifier(id: str) -> str:
99+
"""Return a string representing a column reference in SQL."""
100+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers
101+
# Always escape; otherwise we would need to check against every reserved SQL keyword.
102+
return f"`{_escape_chars(id)}`"
96103

97104

98105
def _escape_chars(value: str):

bigframes/core/compile/googlesql/query.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616

1717
import dataclasses
1818
import typing
19-
from typing import TYPE_CHECKING
19+
20+
import google.cloud.bigquery as bigquery
2021

2122
import bigframes.core.compile.googlesql.abc as abc
2223
import bigframes.core.compile.googlesql.expression as expr
2324

24-
if TYPE_CHECKING:
25-
import google.cloud.bigquery as bigquery
26-
2725
"""This module provides a structured representation of GoogleSQL syntax using nodes.
2826
Each node's name and child nodes are designed to strictly follow the official GoogleSQL
2927
syntax rules outlined in the documentation:
3028
https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax"""
3129

30+
TABLE_SOURCE_TYPE = typing.Union[str, bigquery.TableReference]
31+
3232

3333
@dataclasses.dataclass
3434
class QueryExpr(abc.SQLSyntax):
@@ -53,11 +53,47 @@ def sql(self) -> str:
5353
class Select(abc.SQLSyntax):
5454
"""This class represents GoogleSQL `select` syntax."""
5555

56-
select_list: typing.Sequence[typing.Union[SelectExpression, SelectAll]]
57-
from_clause_list: typing.Sequence[FromClause] = ()
56+
select_list: typing.Sequence[
57+
typing.Union[SelectExpression, SelectAll]
58+
] = dataclasses.field(default_factory=list)
59+
from_clause_list: typing.Sequence[FromClause] = dataclasses.field(
60+
default_factory=list
61+
)
5862
distinct: bool = False
5963

64+
def select(
65+
self,
66+
columns: typing.Union[typing.Iterable[str], str, None] = None,
67+
distinct: bool = False,
68+
) -> Select:
69+
if isinstance(columns, str):
70+
columns = [columns]
71+
self.select_list: typing.List[typing.Union[SelectExpression, SelectAll]] = (
72+
[
73+
SelectExpression(expression=expr.ColumnExpression(name=column))
74+
for column in columns
75+
]
76+
if columns
77+
else [SelectAll(expression=expr.StarExpression())]
78+
)
79+
self.distinct = distinct
80+
return self
81+
82+
def from_(
83+
self,
84+
sources: typing.Union[TABLE_SOURCE_TYPE, typing.Iterable[TABLE_SOURCE_TYPE]],
85+
) -> Select:
86+
if (not isinstance(sources, typing.Iterable)) or isinstance(sources, str):
87+
sources = [sources]
88+
self.from_clause_list = [
89+
FromClause(FromItem.from_source(source)) for source in sources
90+
]
91+
return self
92+
6093
def sql(self) -> str:
94+
if (self.select_list is not None) and (not self.select_list):
95+
raise ValueError("Select clause has not been properly initialized.")
96+
6197
text = ["SELECT"]
6298

6399
if self.distinct:
@@ -66,7 +102,7 @@ def sql(self) -> str:
66102
select_list_sql = ",\n".join([select.sql() for select in self.select_list])
67103
text.append(select_list_sql)
68104

69-
if self.from_clause_list is not None:
105+
if self.from_clause_list:
70106
from_clauses_sql = ",\n".join(
71107
[clause.sql() for clause in self.from_clause_list]
72108
)
@@ -118,19 +154,27 @@ class FromItem(abc.SQLSyntax):
118154
as_alias: typing.Optional[AsAlias] = None
119155

120156
@classmethod
121-
def from_table_ref(
157+
def from_source(
122158
cls,
123-
table_ref: bigquery.TableReference,
159+
subquery_or_tableref: typing.Union[bigquery.TableReference, str],
124160
as_alias: typing.Optional[AsAlias] = None,
125161
):
126-
return cls(
127-
expression=expr.TableExpression(
128-
table_id=table_ref.table_id,
129-
dataset_id=table_ref.dataset_id,
130-
project_id=table_ref.project,
131-
),
132-
as_alias=as_alias,
133-
)
162+
if isinstance(subquery_or_tableref, bigquery.TableReference):
163+
return cls(
164+
expression=expr.TableExpression(
165+
table_id=subquery_or_tableref.table_id,
166+
dataset_id=subquery_or_tableref.dataset_id,
167+
project_id=subquery_or_tableref.project,
168+
),
169+
as_alias=as_alias,
170+
)
171+
elif isinstance(subquery_or_tableref, str):
172+
return cls(
173+
expression=subquery_or_tableref,
174+
as_alias=as_alias,
175+
)
176+
else:
177+
raise ValueError("The source must be bigquery.TableReference or str.")
134178

135179
def sql(self) -> str:
136180
if isinstance(self.expression, (expr.TableExpression, expr.CTEExpression)):

bigframes/core/sql.py

Lines changed: 4 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@
2323

2424
import bigframes.core.compile.googlesql as googlesql
2525

26-
# Literals and identifiers matching this pattern can be unquoted
27-
unquoted = r"^[A-Za-z_][A-Za-z_0-9]*$"
28-
29-
3026
if TYPE_CHECKING:
3127
import google.cloud.bigquery as bigquery
3228

@@ -62,23 +58,16 @@ def multi_literal(*values: str):
6258
return "(" + ", ".join(literal_strings) + ")"
6359

6460

65-
def identifier(id: str) -> str:
66-
"""Return a string representing column reference in a SQL."""
67-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers
68-
# Just always escape, otherwise need to check against every reserved sql keyword
69-
return f"`{googlesql._escape_chars(id)}`"
70-
71-
7261
def cast_as_string(column_name: str) -> str:
7362
"""Return a string representing string casting of a column."""
7463

75-
return f"CAST({identifier(column_name)} AS STRING)"
64+
return f"CAST({googlesql.identifier(column_name)} AS STRING)"
7665

7766

7867
def to_json_string(column_name: str) -> str:
7968
"""Return a string representing JSON version of a column."""
8069

81-
return f"TO_JSON_STRING({identifier(column_name)})"
70+
return f"TO_JSON_STRING({googlesql.identifier(column_name)})"
8271

8372

8473
def csv(values: Iterable[str]) -> str:
@@ -91,55 +80,12 @@ def infix_op(opname: str, left_arg: str, right_arg: str):
9180
return f"{left_arg} {opname} {right_arg}"
9281

9382

94-
### Writing SELECT expressions
95-
def select_from_subquery(columns: Iterable[str], subquery: str, distinct: bool = False):
96-
select_list = [
97-
googlesql.SelectExpression(expression=googlesql.ColumnExpression(name=column))
98-
for column in columns
99-
]
100-
from_clause_list = [googlesql.FromClause(googlesql.FromItem(expression=subquery))]
101-
102-
select_expr = googlesql.Select(
103-
select_list=select_list, from_clause_list=from_clause_list, distinct=distinct
104-
)
105-
return select_expr.sql()
106-
107-
108-
def select_from_table_ref(
109-
columns: Iterable[str], table_ref: bigquery.TableReference, distinct: bool = False
110-
):
111-
select_list = [
112-
googlesql.SelectExpression(expression=googlesql.ColumnExpression(name=column))
113-
for column in columns
114-
]
115-
from_clause_list = [
116-
googlesql.FromClause(googlesql.FromItem.from_table_ref(table_ref))
117-
]
118-
119-
select_expr = googlesql.Select(
120-
select_list=select_list, from_clause_list=from_clause_list, distinct=distinct
121-
)
122-
return select_expr.sql()
123-
124-
125-
def select_table(table_ref: bigquery.TableReference):
126-
select_list = [googlesql.SelectAll(expression=googlesql.StarExpression())]
127-
from_clause_list = [
128-
googlesql.FromClause(googlesql.FromItem.from_table_ref(table_ref))
129-
]
130-
131-
select_expr = googlesql.Select(
132-
select_list=select_list, from_clause_list=from_clause_list
133-
)
134-
return select_expr.sql()
135-
136-
13783
def is_distinct_sql(columns: Iterable[str], table_ref: bigquery.TableReference) -> str:
13884
is_unique_sql = f"""WITH full_table AS (
139-
{select_from_table_ref(columns, table_ref)}
85+
{googlesql.Select().from_(table_ref).select(columns).sql()}
14086
),
14187
distinct_table AS (
142-
{select_from_table_ref(columns, table_ref, distinct=True)}
88+
{googlesql.Select().from_(table_ref).select(columns, distinct=True).sql()}
14389
)
14490
14591
SELECT (SELECT COUNT(*) FROM full_table) AS `total_count`,

bigframes/session/_io/bigquery/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import bigframes
3434
from bigframes.core import log_adapter
35+
import bigframes.core.compile.googlesql as googlesql
3536
import bigframes.core.sql
3637
import bigframes.formatting_helpers as formatting_helpers
3738

@@ -480,7 +481,7 @@ def compile_filters(filters: third_party_pandas_gbq.FiltersType) -> str:
480481

481482
operator_str = valid_operators[operator]
482483

483-
column_ref = bigframes.core.sql.identifier(column)
484+
column_ref = googlesql.identifier(column)
484485
if operator_str in ["IN", "NOT IN"]:
485486
value_literal = bigframes.core.sql.multi_literal(*value)
486487
else:

0 commit comments

Comments
 (0)