Commit 3388191

feat: support names parameter in read_csv for bigquery engine (#1659)
1 parent: ae312db

7 files changed: +205 -49 lines changed
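
Before this commit, passing `names` to `read_csv` with `engine="bigquery"` raised `NotImplementedError`; it is now threaded through the table-loading path. A minimal usage sketch (the bucket path and column names below are placeholders, not taken from the diff):

    import bigframes.pandas as bpd

    # With the BigQuery engine, the CSV is ingested via a load job and the
    # user-supplied `names` rename the loaded table's columns afterwards.
    df = bpd.read_csv(
        "gs://my-bucket/data.csv",  # placeholder path
        engine="bigquery",
        names=["city", "year", "population"],
    )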

bigframes/core/utils.py

Lines changed: 4 additions & 2 deletions
@@ -41,8 +41,10 @@ def get_axis_number(axis: typing.Union[str, int]) -> typing.Literal[0, 1]:
     raise ValueError(f"Not a valid axis: {axis}")


-def is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:
-    return pd.api.types.is_list_like(obj)
+def is_list_like(
+    obj: typing.Any, allow_sets: bool = True
+) -> typing_extensions.TypeGuard[typing.Sequence]:
+    return pd.api.types.is_list_like(obj, allow_sets=allow_sets)


 def is_dict_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Mapping]:
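
The `allow_sets` passthrough matters because pandas considers a `set` list-like, yet sets are unordered and therefore unusable as positional column names. For example:

    import pandas as pd

    pd.api.types.is_list_like({"a", "b"})                    # True: a set is list-like
    pd.api.types.is_list_like({"a", "b"}, allow_sets=False)  # False: rejected when order matters
    pd.api.types.is_list_like(["a", "b"], allow_sets=False)  # True: lists preserve order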

bigframes/session/__init__.py

Lines changed: 20 additions & 16 deletions
@@ -16,6 +16,7 @@

 from __future__ import annotations

+from collections import abc
 import datetime
 import logging
 import os

@@ -569,7 +570,7 @@ def read_gbq_table(
             columns = col_order

         return self._loader.read_gbq_table(
-            query=query,
+            table_id=query,
             index_col=index_col,
             columns=columns,
             max_results=max_results,

@@ -953,14 +954,21 @@ def _read_csv_w_bigquery_engine(
         native CSV loading capabilities, making it suitable for large datasets
         that may not fit into local memory.
         """
-
-        if any(param is not None for param in (dtype, names)):
-            not_supported = ("dtype", "names")
+        if dtype is not None:
             raise NotImplementedError(
-                f"BigQuery engine does not support these arguments: {not_supported}. "
+                f"BigQuery engine does not support the `dtype` argument."
                 f"{constants.FEEDBACK_LINK}"
             )

+        if names is not None:
+            if len(names) != len(set(names)):
+                raise ValueError("Duplicated names are not allowed.")
+            if not (
+                bigframes.core.utils.is_list_like(names, allow_sets=False)
+                or isinstance(names, abc.KeysView)
+            ):
+                raise ValueError("Names should be an ordered collection.")
+
         if index_col is True:
             raise ValueError("The value of index_col couldn't be 'True'")
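
The two new checks mirror pandas' own `read_csv` validation: duplicate names are rejected, and so are unordered collections, with `dict.keys()` views explicitly allowed because they preserve insertion order. A standalone sketch of the same logic (not the bigframes API itself):

    from collections import abc

    import pandas as pd

    def validate_names(names) -> None:
        # Duplicates would make the column -> name mapping ambiguous.
        if len(names) != len(set(names)):
            raise ValueError("Duplicated names are not allowed.")
        # Plain sets are list-like but unordered, so they are refused; a dict
        # KeysView is set-like yet ordered, hence the explicit carve-out.
        if not (
            pd.api.types.is_list_like(names, allow_sets=False)
            or isinstance(names, abc.KeysView)
        ):
            raise ValueError("Names should be an ordered collection.")

    validate_names(["city", "year"])               # OK: ordered and unique
    validate_names({"city": 0, "year": 1}.keys())  # OK: KeysView keeps order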

@@ -1004,11 +1012,9 @@ def _read_csv_w_bigquery_engine(
         elif header > 0:
             job_config.skip_leading_rows = header + 1

-        return self._loader.read_bigquery_load_job(
-            filepath_or_buffer,
-            job_config=job_config,
-            index_col=index_col,
-            columns=columns,
+        table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
+        return self._loader.read_gbq_table(
+            table_id, index_col=index_col, columns=columns, names=names
         )

     def read_pickle(

@@ -1049,8 +1055,8 @@ def read_parquet(
             job_config = bigquery.LoadJobConfig()
             job_config.source_format = bigquery.SourceFormat.PARQUET
             job_config.labels = {"bigframes-api": "read_parquet"}
-
-            return self._loader.read_bigquery_load_job(path, job_config=job_config)
+            table_id = self._loader.load_file(path, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if "*" in path:
                 raise ValueError(

@@ -1121,10 +1127,8 @@ def read_json(
             job_config.encoding = encoding
             job_config.labels = {"bigframes-api": "read_json"}

-            return self._loader.read_bigquery_load_job(
-                path_or_buf,
-                job_config=job_config,
-            )
+            table_id = self._loader.load_file(path_or_buf, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if any(arg in kwargs for arg in ("chunksize", "iterator")):
                 raise NotImplementedError(
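
These three call sites all switch from the one-shot `read_bigquery_load_job` to a two-step flow: `load_file` runs the BigQuery load job and returns a fully qualified table id, and `read_gbq_table` turns that id into a DataFrame. Splitting the steps lets `read_csv` reuse the same table-reading path, including the new `names` handling, that `read_gbq` already exercises. A toy re-creation of the shape, with stand-in functions rather than the real loader:

    def load_file(path: str) -> str:
        # Stand-in: in bigframes this runs a BigQuery load job into a temp table.
        return "my-project.my_dataset._tmp_abc123"  # hypothetical table id

    def read_gbq_table(table_id: str, names=None) -> str:
        # Stand-in: in bigframes this builds a DataFrame block over the table.
        return f"DataFrame over {table_id}, columns renamed to {names}"

    table_id = load_file("gs://my-bucket/data.csv")  # placeholder path
    print(read_gbq_table(table_id, names=["city", "year", "population"]))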

bigframes/session/_io/bigquery/read_gbq_table.py

Lines changed: 14 additions & 0 deletions
@@ -235,6 +235,8 @@ def get_index_cols(
     | Iterable[int]
     | int
     | bigframes.enums.DefaultIndexKind,
+    *,
+    names: Optional[Iterable[str]] = None,
 ) -> List[str]:
     """
     If we can get a total ordering from the table, such as via primary key
@@ -245,6 +247,14 @@ def get_index_cols(
     # Transform index_col -> index_cols so we have a variable that is
     # always a list of column names (possibly empty).
     schema_len = len(table.schema)
+
+    # If `names` is provided, the index_col given by the user refers to the new
+    # name, so we need to map it back to the original name in the table schema.
+    renamed_schema: Optional[Dict[str, str]] = None
+    if names is not None:
+        assert len(list(names)) == schema_len
+        renamed_schema = {name: field.name for name, field in zip(names, table.schema)}
+
     index_cols: List[str] = []
     if isinstance(index_col, bigframes.enums.DefaultIndexKind):
         if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
@@ -261,6 +271,8 @@ def get_index_cols(
             f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
         )
     elif isinstance(index_col, str):
+        if renamed_schema is not None:
+            index_col = renamed_schema.get(index_col, index_col)
         index_cols = [index_col]
     elif isinstance(index_col, int):
         if not 0 <= index_col < schema_len:
@@ -272,6 +284,8 @@ def get_index_cols(
     elif isinstance(index_col, Iterable):
         for item in index_col:
             if isinstance(item, str):
+                if renamed_schema is not None:
+                    item = renamed_schema.get(item, item)
                 index_cols.append(item)
             elif isinstance(item, int):
                 if not 0 <= item < schema_len:
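
Because `names` replaces the column labels the user sees, an `index_col` given by its new name has to be translated back to the physical column name before index resolution. A small illustration with hypothetical schema names:

    # Physical column names produced by the load job vs. user-supplied names.
    schema_names = ["string_field_0", "int64_field_1", "int64_field_2"]
    names = ["city", "year", "population"]
    renamed_schema = {new: old for new, old in zip(names, schema_names)}

    index_col = "city"                                    # what the user passed
    index_col = renamed_schema.get(index_col, index_col)  # -> "string_field_0"
    print(index_col)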

bigframes/session/loader.py

Lines changed: 42 additions & 22 deletions
@@ -348,14 +348,15 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):

     def read_gbq_table(
         self,
-        query: str,
+        table_id: str,
         *,
         index_col: Iterable[str]
         | str
         | Iterable[int]
         | int
         | bigframes.enums.DefaultIndexKind = (),
         columns: Iterable[str] = (),
+        names: Optional[Iterable[str]] = None,
         max_results: Optional[int] = None,
         api_name: str = "read_gbq_table",
         use_cache: bool = True,

@@ -375,7 +376,7 @@ def read_gbq_table(
         )

         table_ref = google.cloud.bigquery.table.TableReference.from_string(
-            query, default_project=self._bqclient.project
+            table_id, default_project=self._bqclient.project
         )

         columns = list(columns)

@@ -411,12 +412,37 @@ def read_gbq_table(
                 f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?"
             )

+        # TODO(b/408499371): check `names` work with `use_cols` for read_csv method.
+        if names is not None:
+            len_names = len(list(names))
+            len_columns = len(table.schema)
+            if len_names > len_columns:
+                raise ValueError(
+                    f"Too many columns specified: expected {len_columns}"
+                    f" and found {len_names}"
+                )
+            elif len_names < len_columns:
+                if (
+                    isinstance(index_col, bigframes.enums.DefaultIndexKind)
+                    or index_col != ()
+                ):
+                    raise KeyError(
+                        "When providing both `index_col` and `names`, ensure the "
+                        "number of `names` matches the number of columns in your "
+                        "data."
+                    )
+                index_col = range(len_columns - len_names)
+                names = [
+                    field.name for field in table.schema[: len_columns - len_names]
+                ] + list(names)
+
         # Converting index_col into a list of column names requires
         # the table metadata because we might use the primary keys
         # when constructing the index.
         index_cols = bf_read_gbq_table.get_index_cols(
             table=table,
             index_col=index_col,
+            names=names,
         )
         _check_column_duplicates(index_cols, columns)
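
When fewer `names` than table columns are given, the loader follows the pandas `read_csv` convention: the leading unnamed columns implicitly become the index (unless an explicit `index_col` makes that ambiguous, which raises). A quick sketch of the arithmetic with a hypothetical four-column schema:

    schema_cols = ["f0", "f1", "f2", "f3"]  # hypothetical loaded-table columns
    names = ["x", "y"]                      # two names for four columns

    extra = len(schema_cols) - len(names)      # 2 leading columns stay unnamed
    index_col = range(extra)                   # columns 0 and 1 become the index
    names = schema_cols[:extra] + list(names)  # -> ["f0", "f1", "x", "y"]
    print(list(index_col), names)              # [0, 1] ['f0', 'f1', 'x', 'y']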

@@ -443,15 +469,15 @@ def read_gbq_table(
         # TODO(b/338419730): We don't need to fallback to a query for wildcard
         # tables if we allow some non-determinism when time travel isn't supported.
         if max_results is not None or bf_io_bigquery.is_table_with_wildcard_suffix(
-            query
+            table_id
         ):
             # TODO(b/338111344): If we are running a query anyway, we might as
             # well generate ROW_NUMBER() at the same time.
             all_columns: Iterable[str] = (
                 itertools.chain(index_cols, columns) if columns else ()
             )
             query = bf_io_bigquery.to_query(
-                query,
+                table_id,
                 columns=all_columns,
                 sql_predicate=bf_io_bigquery.compile_filters(filters)
                 if filters

@@ -561,6 +587,15 @@ def read_gbq_table(
             index_names = [None]

         value_columns = [col for col in array_value.column_ids if col not in index_cols]
+        if names is not None:
+            renamed_cols: Dict[str, str] = {
+                col: new_name for col, new_name in zip(array_value.column_ids, names)
+            }
+            index_names = [
+                renamed_cols.get(index_col, index_col) for index_col in index_cols
+            ]
+            value_columns = [renamed_cols.get(col, col) for col in value_columns]
+
         block = blocks.Block(
             array_value,
             index_columns=index_cols,

@@ -576,18 +611,12 @@ def read_gbq_table(
             df.sort_index()
         return df

-    def read_bigquery_load_job(
+    def load_file(
         self,
         filepath_or_buffer: str | IO["bytes"],
         *,
         job_config: bigquery.LoadJobConfig,
-        index_col: Iterable[str]
-        | str
-        | Iterable[int]
-        | int
-        | bigframes.enums.DefaultIndexKind = (),
-        columns: Iterable[str] = (),
-    ) -> dataframe.DataFrame:
+    ) -> str:
         # Need to create session table beforehand
         table = self._storage_manager.create_temp_table(_PLACEHOLDER_SCHEMA)
         # but, we just overwrite the placeholder schema immediately with the load job

@@ -615,16 +644,7 @@ def read_bigquery_load_job(

         self._start_generic_job(load_job)
         table_id = f"{table.project}.{table.dataset_id}.{table.table_id}"
-
-        # The BigQuery REST API for tables.get doesn't take a session ID, so we
-        # can't get the schema for a temp table that way.
-
-        return self.read_gbq_table(
-            query=table_id,
-            index_col=index_col,
-            columns=columns,
-            api_name="read_gbq_table",
-        )
+        return table_id

     def read_gbq_query(
         self,
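
The final pass swaps physical column ids for the user-supplied names on both sides of the block: index labels and value columns. Illustrated with hypothetical ids:

    column_ids = ["f0", "f1", "f2"]         # hypothetical physical column ids
    names = ["city", "year", "population"]  # user-supplied names
    renamed_cols = dict(zip(column_ids, names))

    index_cols = ["f0"]  # resolved earlier by get_index_cols
    index_names = [renamed_cols.get(col, col) for col in index_cols]
    value_columns = [renamed_cols.get(col, col) for col in column_ids if col not in index_cols]
    print(index_names, value_columns)  # ['city'] ['year', 'population']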
