Skip to content

Commit eeee1c6

Browse files
committed
fix bigquery -> psql transpilation of GENERATE_ARRAY
1 parent ae7ae9a commit eeee1c6

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

src/mimic_utils/sqlglot_dialects/postgres.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sqlglot
22
import sqlglot.dialects.postgres
3-
from sqlglot import Expression, exp, select
3+
from sqlglot import Expression, exp
4+
from sqlglot.expressions import array, select
45

56
# DATETIME: allow passing either a DATE directly, or multiple arguments
67
# there isn't a class for the Datetime function, so we have to create it ourselves,
@@ -94,18 +95,20 @@ def datetime_sql(self: Expression, expression: Expression):
9495
# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_array
9596
# https://www.postgresql.org/docs/current/functions-srf.html
9697
def generate_array_sql(self: Expression, expression: Expression):
97-
# first create a select statement which selects from generate_series
98-
select_statement = select("*").from_(
98+
# BigQuery's GENERATE_ARRAY returns an array data type,
99+
# but PostgreSQL's GENERATE_SERIES returns a set of rows,
100+
# so we wrap the output of GENERATE_SERIES in an array
101+
# constructor.
102+
select_statement = array(select("*").from_(
99103
GenerateSeries(
100104
expressions=[
101105
expression.expressions[0],
102106
expression.expressions[1],
103107
],
104108
)
105-
)
109+
))
106110

107-
# now convert the select statement to an array
108-
return f"ARRAY({self.sql(select_statement)})"
111+
return self.generate(select_statement)
109112
sqlglot.dialects.postgres.Postgres.Generator.TRANSFORMS[GenerateArray] = generate_array_sql
110113

111114
# we need to prevent the wrapping of the table alias in brackets for UNNEST

src/mimic_utils/transpile.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import sqlglot.dialects.bigquery
77
import sqlglot.dialects.duckdb
88
import sqlglot.dialects.postgres
9-
from sqlglot import Expression, exp, select
10-
from sqlglot.helper import seq_get
9+
from sqlglot import exp, select, alias
10+
from sqlglot.expressions import Array, array
1111

1212
# Apply transformation monkey patches
1313
# these modules are imported for their side effects
@@ -47,6 +47,26 @@ def transpile_query(query: str, source_dialect: str="bigquery", destination_dial
4747
quoted=False
4848
)
4949

50+
# HACK: sqlglot has a GenerateSeries transpilation in v25.13.0,
51+
# which is inserted during the parse of BigQuery. However, it looks
52+
# incorrect for postgres (at least), as it swaps GENERATE_ARRAY for GENERATE_SERIES.
53+
# BigQuery's GENERATE_ARRAY outputs an array, but GENERATE_SERIES outputs exploded rows.
54+
# We will manually replace the GENERATE_SERIES call with an anonymous function, so our
55+
# custom transpile code can do the correct conversion for postgres.
56+
if (source_dialect == 'bigquery') and (destination_dialect == 'postgres'):
57+
for gs_function in sql_parsed.find_all(exp.GenerateSeries):
58+
# rename to our anonymous generate array function, so the
59+
# later loop will catch it
60+
gs_function.replace(
61+
exp.Anonymous(
62+
this='GENERATE_ARRAY',
63+
expressions=[
64+
gs_function.args['start'],
65+
gs_function.args['end']
66+
]
67+
)
68+
)
69+
5070
# BigQuery has a few functions which are not in sqlglot, so we have
5171
# created classes for them, and this loop replaces the anonymous functions
5272
# with the named functions

0 commit comments

Comments
 (0)