Skip to content

Commit 36a3314

Browse files
authored
refactor: support all data types for readlocal compiler (#1666)
1 parent 9ac8135 commit 36a3314

File tree

11 files changed

+427
-55
lines changed

11 files changed

+427
-55
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,22 +146,17 @@ def _compile_node(
146146

147147
@_compile_node.register
148148
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
149-
offsets = node.offsets_col.sql if node.offsets_col else None
150-
schema_names = node.schema.names
151-
schema_dtypes = node.schema.dtypes
152-
153149
pa_table = node.local_data_source.data
154150
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
155-
pa_table = pa_table.rename_columns(
156-
{item.source_id: item.id.sql for item in node.scan_list.items}
157-
)
151+
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
158152

153+
offsets = node.offsets_col.sql if node.offsets_col else None
159154
if offsets:
160155
pa_table = pa_table.append_column(
161156
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
162157
)
163158

164-
return ir.SQLGlotIR.from_pandas(pa_table.to_pandas(), schema_names, schema_dtypes)
159+
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema)
165160

166161

167162
@_compile_node.register

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,23 @@
1717
import dataclasses
1818
import typing
1919

20-
import pandas as pd
20+
import pyarrow as pa
2121
import sqlglot as sg
2222
import sqlglot.dialects.bigquery
2323
import sqlglot.expressions as sge
2424

2525
from bigframes import dtypes
2626
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
27+
import bigframes.core.local_data as local_data
28+
import bigframes.core.schema as schemata
29+
30+
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
31+
try:
32+
from shapely.io import to_wkt # type: ignore
33+
except ImportError:
34+
from shapely.wkt import dumps # type: ignore
35+
36+
to_wkt = dumps
2737

2838

2939
@dataclasses.dataclass(frozen=True)
@@ -48,35 +58,32 @@ def sql(self) -> str:
4858
return self.expr.sql(dialect=self.dialect, pretty=self.pretty)
4959

5060
@classmethod
51-
def from_pandas(
52-
cls,
53-
pd_df: pd.DataFrame,
54-
schema_names: typing.Sequence[str],
55-
schema_dtypes: typing.Sequence[dtypes.Dtype],
61+
def from_pyarrow(
62+
cls, pa_table: pa.Table, schema: schemata.ArraySchema
5663
) -> SQLGlotIR:
5764
"""Builds SQLGlot expression from pyarrow table."""
5865
dtype_expr = sge.DataType(
5966
this=sge.DataType.Type.STRUCT,
6067
expressions=[
6168
sge.ColumnDef(
62-
this=sge.to_identifier(name, quoted=True),
63-
kind=sgt.SQLGlotType.from_bigframes_dtype(dtype),
69+
this=sge.to_identifier(field.column, quoted=True),
70+
kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype),
6471
)
65-
for name, dtype in zip(schema_names, schema_dtypes)
72+
for field in schema.items
6673
],
6774
nested=True,
6875
)
6976
data_expr = [
70-
sge.Tuple(
77+
sge.Struct(
7178
expressions=tuple(
7279
_literal(
7380
value=value,
74-
dtype=sgt.SQLGlotType.from_bigframes_dtype(dtype),
81+
dtype=field.dtype,
7582
)
76-
for value, dtype in zip(row, schema_dtypes)
83+
for value, field in zip(tuple(row_dict.values()), schema.items)
7784
)
7885
)
79-
for _, row in pd_df.iterrows()
86+
for row_dict in local_data._iter_table(pa_table, schema)
8087
]
8188
expr = sge.Unnest(
8289
expressions=[
@@ -105,13 +112,36 @@ def select(
105112
return SQLGlotIR(expr=expr)
106113

107114

108-
def _literal(value: typing.Any, dtype: str) -> sge.Expression:
115+
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
116+
sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype)
109117
if value is None:
110-
return _cast(sge.Null(), dtype)
111-
112-
# TODO: handle other types like visit_DefaultLiteral
113-
return sge.convert(value)
118+
return _cast(sge.Null(), sqlglot_type)
119+
elif dtype == dtypes.BYTES_DTYPE:
120+
return _cast(str(value), sqlglot_type)
121+
elif dtypes.is_time_like(dtype):
122+
return _cast(sge.convert(value.isoformat()), sqlglot_type)
123+
elif dtypes.is_geo_like(dtype):
124+
wkt = value if isinstance(value, str) else to_wkt(value)
125+
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
126+
elif dtype == dtypes.JSON_DTYPE:
127+
return sge.ParseJSON(this=sge.convert(str(value)))
128+
elif dtypes.is_struct_like(dtype):
129+
items = [
130+
_literal(value=value[field_name], dtype=field_dtype).as_(
131+
field_name, quoted=True
132+
)
133+
for field_name, field_dtype in dtypes.get_struct_fields(dtype).items()
134+
]
135+
return sge.Struct.from_arg_list(items)
136+
elif dtypes.is_array_like(dtype):
137+
value_type = dtypes.get_array_inner_type(dtype)
138+
values = sge.Array(
139+
expressions=[_literal(value=v, dtype=value_type) for v in value]
140+
)
141+
return values if len(value) > 0 else _cast(values, sqlglot_type)
142+
else:
143+
return sge.convert(value)
114144

115145

116-
def _cast(arg, to) -> sge.Cast:
146+
def _cast(arg: typing.Any, to: str) -> sge.Cast:
117147
return sge.Cast(this=arg, to=to)

tests/data/scalars.jsonl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
{"bool_col": true, "bytes_col": "SGVsbG8sIFdvcmxkIQ==", "date_col": "2021-07-21", "datetime_col": "2021-07-21 11:39:45", "geography_col": "POINT(-122.0838511 37.3860517)", "int64_col": "123456789", "int64_too": "0", "numeric_col": "1.23456789", "float64_col": "1.25", "rowindex": 0, "rowindex_2": 0, "string_col": "Hello, World!", "time_col": "11:41:43.076160", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
2-
{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "1991-02-03", "datetime_col": "1991-01-02 03:45:06", "geography_col": "POINT(-71.104 42.315)", "int64_col": "-987654321", "int64_too": "1", "numeric_col": "1.23456789", "float64_col": "2.51", "rowindex": 1, "rowindex_2": 1, "string_col": "こんにちは", "time_col": "11:14:34.701606", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
3-
{"bool_col": true, "bytes_col": "wqFIb2xhIE11bmRvIQ==", "date_col": "2023-03-01", "datetime_col": "2023-03-01 10:55:13", "geography_col": "POINT(-0.124474760143016 51.5007826749545)", "int64_col": "314159", "int64_too": "0", "numeric_col": "101.1010101", "float64_col": "2.5e10", "rowindex": 2, "rowindex_2": 2, "string_col": " ¡Hola Mundo! ", "time_col": "23:59:59.999999", "timestamp_col": "2023-03-01T10:55:13.250125Z"}
4-
{"bool_col": null, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": null, "int64_too": "1", "numeric_col": null, "float64_col": null, "rowindex": 3, "rowindex_2": 3, "string_col": null, "time_col": null, "timestamp_col": null}
5-
{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "2021-07-21", "datetime_col": null, "geography_col": null, "int64_col": "-234892", "int64_too": "-2345", "numeric_col": null, "float64_col": null, "rowindex": 4, "rowindex_2": 4, "string_col": "Hello, World!", "time_col": null, "timestamp_col": null}
6-
{"bool_col": false, "bytes_col": "R8O8dGVuIFRhZw==", "date_col": "1980-03-14", "datetime_col": "1980-03-14 15:16:17", "geography_col": null, "int64_col": "55555", "int64_too": "0", "numeric_col": "5.555555", "float64_col": "555.555", "rowindex": 5, "rowindex_2": 5, "string_col": "Güten Tag!", "time_col": "15:16:17.181921", "timestamp_col": "1980-03-14T15:16:17.181921Z"}
7-
{"bool_col": true, "bytes_col": "SGVsbG8JQmlnRnJhbWVzIQc=", "date_col": "2023-05-23", "datetime_col": "2023-05-23 11:37:01", "geography_col": "MULTIPOINT (20 20, 10 40, 40 30, 30 10)", "int64_col": "101202303", "int64_too": "2", "numeric_col": "-10.090807", "float64_col": "-123.456", "rowindex": 6, "rowindex_2": 6, "string_col": "capitalize, This ", "time_col": "01:02:03.456789", "timestamp_col": "2023-05-23T11:42:55.000001Z"}
8-
{"bool_col": true, "bytes_col": null, "date_col": "2038-01-20", "datetime_col": "2038-01-19 03:14:08", "geography_col": null, "int64_col": "-214748367", "int64_too": "2", "numeric_col": "11111111.1", "float64_col": "42.42", "rowindex": 7, "rowindex_2": 7, "string_col": " سلام", "time_col": "12:00:00.000001", "timestamp_col": "2038-01-19T03:14:17.999999Z"}
9-
{"bool_col": false, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": "2", "int64_too": "1", "numeric_col": null, "float64_col": "6.87", "rowindex": 8, "rowindex_2": 8, "string_col": "T", "time_col": null, "timestamp_col": null}
1+
{"bool_col": true, "bytes_col": "SGVsbG8sIFdvcmxkIQ==", "date_col": "2021-07-21", "datetime_col": "2021-07-21 11:39:45", "geography_col": "POINT(-122.0838511 37.3860517)", "int64_col": "123456789", "int64_too": "0", "numeric_col": "1.23456789", "float64_col": "1.25", "rowindex": 0, "rowindex_2": 0, "string_col": "Hello, World!", "time_col": "11:41:43.076160", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
2+
{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "1991-02-03", "datetime_col": "1991-01-02 03:45:06", "geography_col": "POINT(-71.104 42.315)", "int64_col": "-987654321", "int64_too": "1", "numeric_col": "1.23456789", "float64_col": "2.51", "rowindex": 1, "rowindex_2": 1, "string_col": "こんにちは", "time_col": "11:14:34.701606", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
3+
{"bool_col": true, "bytes_col": "wqFIb2xhIE11bmRvIQ==", "date_col": "2023-03-01", "datetime_col": "2023-03-01 10:55:13", "geography_col": "POINT(-0.124474760143016 51.5007826749545)", "int64_col": "314159", "int64_too": "0", "numeric_col": "101.1010101", "float64_col": "2.5e10", "rowindex": 2, "rowindex_2": 2, "string_col": " ¡Hola Mundo! ", "time_col": "23:59:59.999999", "timestamp_col": "2023-03-01T10:55:13.250125Z"}
4+
{"bool_col": null, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": null, "int64_too": "1", "numeric_col": null, "float64_col": null, "rowindex": 3, "rowindex_2": 3, "string_col": null, "time_col": null, "timestamp_col": null}
5+
{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "2021-07-21", "datetime_col": null, "geography_col": null, "int64_col": "-234892", "int64_too": "-2345", "numeric_col": null, "float64_col": null, "rowindex": 4, "rowindex_2": 4, "string_col": "Hello, World!", "time_col": null, "timestamp_col": null}
6+
{"bool_col": false, "bytes_col": "R8O8dGVuIFRhZw==", "date_col": "1980-03-14", "datetime_col": "1980-03-14 15:16:17", "geography_col": null, "int64_col": "55555", "int64_too": "0", "numeric_col": "5.555555", "float64_col": "555.555", "rowindex": 5, "rowindex_2": 5, "string_col": "Güten Tag!", "time_col": "15:16:17.181921", "timestamp_col": "1980-03-14T15:16:17.181921Z"}
7+
{"bool_col": true, "bytes_col": "SGVsbG8JQmlnRnJhbWVzIQc=", "date_col": "2023-05-23", "datetime_col": "2023-05-23 11:37:01", "geography_col": "LINESTRING(-0.127959 51.507728, -0.127026 51.507473)", "int64_col": "101202303", "int64_too": "2", "numeric_col": "-10.090807", "float64_col": "-123.456", "rowindex": 6, "rowindex_2": 6, "string_col": "capitalize, This ", "time_col": "01:02:03.456789", "timestamp_col": "2023-05-23T11:42:55.000001Z"}
8+
{"bool_col": true, "bytes_col": null, "date_col": "2038-01-20", "datetime_col": "2038-01-19 03:14:08", "geography_col": null, "int64_col": "-214748367", "int64_too": "2", "numeric_col": "11111111.1", "float64_col": "42.42", "rowindex": 7, "rowindex_2": 7, "string_col": " سلام", "time_col": "12:00:00.000001", "timestamp_col": "2038-01-19T03:14:17.999999Z"}
9+
{"bool_col": false, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": "2", "int64_too": "1", "numeric_col": null, "float64_col": "6.87", "rowindex": 8, "rowindex_2": 8, "string_col": "T", "time_col": null, "timestamp_col": null}

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,101 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pathlib
16+
1517
import pandas as pd
18+
import pyarrow as pa
1619
import pytest
1720

21+
from bigframes import dtypes
22+
import tests.system.utils
23+
24+
CURRENT_DIR = pathlib.Path(__file__).parent
25+
DATA_DIR = CURRENT_DIR.parent.parent.parent.parent / "data"
26+
1827

19-
@pytest.fixture(scope="module")
28+
@pytest.fixture(scope="session")
2029
def compiler_session():
2130
from . import compiler_session
2231

2332
return compiler_session.SQLCompilerSession()
2433

2534

26-
@pytest.fixture(scope="module")
27-
def all_types_df() -> pd.DataFrame:
28-
# TODO: all types pandas dataframes
35+
@pytest.fixture(scope="session")
36+
def scalars_types_pandas_df() -> pd.DataFrame:
37+
"""Returns a pandas DataFrame containing all scalar types and using the `rowindex`
38+
column as the index."""
2939
# TODO: add tests for empty dataframes
40+
df = pd.read_json(
41+
DATA_DIR / "scalars.jsonl",
42+
lines=True,
43+
)
44+
tests.system.utils.convert_pandas_dtypes(df, bytes_col=True)
45+
46+
df = df.set_index("rowindex", drop=False)
47+
return df
48+
49+
50+
@pytest.fixture(scope="session")
51+
def nested_structs_pandas_df() -> pd.DataFrame:
52+
"""Returns a pandas DataFrame containing STRUCT types and using the `id`
53+
column as the index."""
54+
55+
df = pd.read_json(
56+
DATA_DIR / "nested_structs.jsonl",
57+
lines=True,
58+
)
59+
df = df.set_index("id")
60+
61+
address_struct_schema = pa.struct(
62+
[pa.field("city", pa.string()), pa.field("country", pa.string())]
63+
)
64+
person_struct_schema = pa.struct(
65+
[
66+
pa.field("name", pa.string()),
67+
pa.field("age", pa.int64()),
68+
pa.field("address", address_struct_schema),
69+
]
70+
)
71+
df["person"] = df["person"].astype(pd.ArrowDtype(person_struct_schema))
72+
return df
73+
74+
75+
@pytest.fixture(scope="session")
76+
def repeated_pandas_df() -> pd.DataFrame:
77+
"""Returns a pandas DataFrame containing LIST types and using the `rowindex`
78+
column as the index."""
79+
80+
df = pd.read_json(
81+
DATA_DIR / "repeated.jsonl",
82+
lines=True,
83+
)
84+
df = df.set_index("rowindex")
85+
return df
86+
87+
88+
@pytest.fixture(scope="session")
89+
def json_pandas_df() -> pd.DataFrame:
90+
"""Returns a pandas DataFrame containing JSON types and using the `rowindex`
91+
column as the index."""
92+
json_data = [
93+
"null",
94+
"true",
95+
"100",
96+
"0.98",
97+
'"a string"',
98+
"[]",
99+
"[1, 2, 3]",
100+
'[{"a": 1}, {"a": 2}, {"a": null}, {}]',
101+
'"100"',
102+
'{"date": "2024-07-16"}',
103+
'{"int_value": 2, "null_filed": null}',
104+
'{"list_data": [10, 20, 30]}',
105+
]
30106
df = pd.DataFrame(
31107
{
32-
"int1": pd.Series([1, 2, 3], dtype="Int64"),
33-
"int2": pd.Series([-10, 20, 30], dtype="Int64"),
34-
"bools": pd.Series([True, None, False], dtype="boolean"),
35-
"strings": pd.Series(["b", "aa", "ccc"], dtype="string[pyarrow]"),
108+
"json_col": pd.Series(json_data, dtype=dtypes.JSON_DTYPE),
36109
},
110+
index=pd.Series(range(len(json_data)), dtype=dtypes.INT_DTYPE),
37111
)
38-
# add more complexity index.
39-
df.index = df.index.astype("Int64")
40112
return df

0 commit comments

Comments
 (0)