Skip to content

Commit 855031a

Browse files
authored
fix: correct read_csv behaviours with use_cols, names, index_col (#1804)
* fix: correct read_csv behaviours with use_cols, names, index_col parameters * fix test_default_index_warning_not_raised_by_read_gbq_primary_key * refactor read_gbq_table for more readable * fix presubmit
1 parent e403528 commit 855031a

File tree

3 files changed

+283
-110
lines changed

3 files changed

+283
-110
lines changed

bigframes/session/_io/bigquery/read_gbq_table.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,25 +243,17 @@ def get_index_cols(
243243
| int
244244
| bigframes.enums.DefaultIndexKind,
245245
*,
246-
names: Optional[Iterable[str]] = None,
246+
rename_to_schema: Optional[Dict[str, str]] = None,
247247
) -> List[str]:
248248
"""
249249
If we can get a total ordering from the table, such as via primary key
250250
column(s), then return those too so that ordering generation can be
251251
avoided.
252252
"""
253-
254253
# Transform index_col -> index_cols so we have a variable that is
255254
# always a list of column names (possibly empty).
256255
schema_len = len(table.schema)
257256

258-
# If the `names` is provided, the index_col provided by the user is the new
259-
# name, so we need to rename it to the original name in the table schema.
260-
renamed_schema: Optional[Dict[str, str]] = None
261-
if names is not None:
262-
assert len(list(names)) == schema_len
263-
renamed_schema = {name: field.name for name, field in zip(names, table.schema)}
264-
265257
index_cols: List[str] = []
266258
if isinstance(index_col, bigframes.enums.DefaultIndexKind):
267259
if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
@@ -278,8 +270,8 @@ def get_index_cols(
278270
f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
279271
)
280272
elif isinstance(index_col, str):
281-
if renamed_schema is not None:
282-
index_col = renamed_schema.get(index_col, index_col)
273+
if rename_to_schema is not None:
274+
index_col = rename_to_schema.get(index_col, index_col)
283275
index_cols = [index_col]
284276
elif isinstance(index_col, int):
285277
if not 0 <= index_col < schema_len:
@@ -291,8 +283,8 @@ def get_index_cols(
291283
elif isinstance(index_col, Iterable):
292284
for item in index_col:
293285
if isinstance(item, str):
294-
if renamed_schema is not None:
295-
item = renamed_schema.get(item, item)
286+
if rename_to_schema is not None:
287+
item = rename_to_schema.get(item, item)
296288
index_cols.append(item)
297289
elif isinstance(item, int):
298290
if not 0 <= item < schema_len:

bigframes/session/loader.py

Lines changed: 151 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,35 @@ def _to_index_cols(
9696
return index_cols
9797

9898

99-
def _check_column_duplicates(
100-
index_cols: Iterable[str], columns: Iterable[str], index_col_in_columns: bool
101-
) -> Iterable[str]:
102-
"""Validates and processes index and data columns for duplicates and overlap.
99+
def _check_duplicates(name: str, columns: Optional[Iterable[str]] = None):
100+
"""Check for duplicate column names in the provided iterable."""
101+
if columns is None:
102+
return
103+
columns_list = list(columns)
104+
set_columns = set(columns_list)
105+
if len(columns_list) > len(set_columns):
106+
raise ValueError(
107+
f"The '{name}' argument contains duplicate names. "
108+
f"All column names specified in '{name}' must be unique."
109+
)
103110

104-
This function performs two main tasks:
105-
1. Ensures there are no duplicate column names within the `index_cols` list
106-
or within the `columns` list.
107-
2. Based on the `index_col_in_columns` flag, it validates the relationship
108-
between `index_cols` and `columns`.
111+
112+
def _check_index_col_param(
113+
index_cols: Iterable[str],
114+
columns: Iterable[str],
115+
*,
116+
table_columns: Optional[Iterable[str]] = None,
117+
index_col_in_columns: Optional[bool] = False,
118+
):
119+
"""Checks for duplicates in `index_cols` and resolves overlap with `columns`.
109120
110121
Args:
111122
index_cols (Iterable[str]):
112-
An iterable of column names designated as the index.
123+
Column names designated as the index columns.
113124
columns (Iterable[str]):
114-
An iterable of column names designated as the data columns.
125+
Used column names from table_columns.
126+
table_columns (Iterable[str]):
127+
A full list of column names in the table schema.
115128
index_col_in_columns (bool):
116129
A flag indicating how to handle overlap between `index_cols` and
117130
`columns`.
@@ -121,40 +134,97 @@ def _check_column_duplicates(
121134
`columns`. An error is raised if an index column is not found
122135
in the `columns` list.
123136
"""
124-
index_cols_list = list(index_cols) if index_cols is not None else []
125-
columns_list = list(columns) if columns is not None else []
126-
set_index = set(index_cols_list)
127-
set_columns = set(columns_list)
137+
_check_duplicates("index_col", index_cols)
128138

129-
if len(index_cols_list) > len(set_index):
130-
raise ValueError(
131-
"The 'index_col' argument contains duplicate names. "
132-
"All column names specified in 'index_col' must be unique."
133-
)
139+
if columns is not None and len(list(columns)) > 0:
140+
set_index = set(list(index_cols) if index_cols is not None else [])
141+
set_columns = set(list(columns) if columns is not None else [])
134142

135-
if len(columns_list) == 0:
136-
return columns
143+
if index_col_in_columns:
144+
if not set_index.issubset(set_columns):
145+
raise ValueError(
146+
f"The specified index column(s) were not found: {set_index - set_columns}. "
147+
f"Available columns are: {set_columns}"
148+
)
149+
else:
150+
if not set_index.isdisjoint(set_columns):
151+
raise ValueError(
152+
"Found column names that exist in both 'index_col' and 'columns' arguments. "
153+
"These arguments must specify distinct sets of columns."
154+
)
137155

138-
if len(columns_list) > len(set_columns):
139-
raise ValueError(
140-
"The 'columns' argument contains duplicate names. "
141-
"All column names specified in 'columns' must be unique."
142-
)
156+
if not index_col_in_columns and table_columns is not None:
157+
for key in index_cols:
158+
if key not in table_columns:
159+
possibility = min(
160+
table_columns,
161+
key=lambda item: bigframes._tools.strings.levenshtein_distance(
162+
key, item
163+
),
164+
)
165+
raise ValueError(
166+
f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?"
167+
)
143168

144-
if index_col_in_columns:
145-
if not set_index.issubset(set_columns):
146-
raise ValueError(
147-
f"The specified index column(s) were not found: {set_index - set_columns}. "
148-
f"Available columns are: {set_columns}"
169+
170+
def _check_columns_param(columns: Iterable[str], table_columns: Iterable[str]):
171+
"""Validates that the specified columns are present in the table columns.
172+
173+
Args:
174+
columns (Iterable[str]):
175+
Used column names from table_columns.
176+
table_columns (Iterable[str]):
177+
A full list of column names in the table schema.
178+
Raises:
179+
ValueError: If any column in `columns` is not found in the table columns.
180+
"""
181+
for column_name in columns:
182+
if column_name not in table_columns:
183+
possibility = min(
184+
table_columns,
185+
key=lambda item: bigframes._tools.strings.levenshtein_distance(
186+
column_name, item
187+
),
149188
)
150-
return [col for col in columns if col not in set_index]
151-
else:
152-
if not set_index.isdisjoint(set_columns):
153189
raise ValueError(
154-
"Found column names that exist in both 'index_col' and 'columns' arguments. "
155-
"These arguments must specify distinct sets of columns."
190+
f"Column '{column_name}' is not found. Did you mean '{possibility}'?"
156191
)
157-
return columns
192+
193+
194+
def _check_names_param(
195+
names: Iterable[str],
196+
index_col: Iterable[str]
197+
| str
198+
| Iterable[int]
199+
| int
200+
| bigframes.enums.DefaultIndexKind,
201+
columns: Iterable[str],
202+
table_columns: Iterable[str],
203+
):
204+
len_names = len(list(names))
205+
len_table_columns = len(list(table_columns))
206+
len_columns = len(list(columns))
207+
if len_names > len_table_columns:
208+
raise ValueError(
209+
f"Too many columns specified: expected {len_table_columns}"
210+
f" and found {len_names}"
211+
)
212+
elif len_names < len_table_columns:
213+
if isinstance(index_col, bigframes.enums.DefaultIndexKind) or index_col != ():
214+
raise KeyError(
215+
"When providing both `index_col` and `names`, ensure the "
216+
"number of `names` matches the number of columns in your "
217+
"data."
218+
)
219+
if len_columns != 0:
220+
# The 'columns' must be identical to the 'names'. If not, raise an error.
221+
if len_columns != len_names:
222+
raise ValueError(
223+
"Number of passed names did not match number of header "
224+
"fields in the file"
225+
)
226+
if set(list(names)) != set(list(columns)):
227+
raise ValueError("Usecols do not match columns")
158228

159229

160230
@dataclasses.dataclass
@@ -545,11 +615,14 @@ def read_gbq_table(
545615
f"`max_results` should be a positive number, got {max_results}."
546616
)
547617

618+
_check_duplicates("columns", columns)
619+
548620
table_ref = google.cloud.bigquery.table.TableReference.from_string(
549621
table_id, default_project=self._bqclient.project
550622
)
551623

552624
columns = list(columns)
625+
include_all_columns = columns is None or len(columns) == 0
553626
filters = typing.cast(list, list(filters))
554627

555628
# ---------------------------------
@@ -563,72 +636,58 @@ def read_gbq_table(
563636
cache=self._df_snapshot,
564637
use_cache=use_cache,
565638
)
566-
table_column_names = {field.name for field in table.schema}
567639

568640
if table.location.casefold() != self._storage_manager.location.casefold():
569641
raise ValueError(
570642
f"Current session is in {self._storage_manager.location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
571643
)
572644

573-
for key in columns:
574-
if key not in table_column_names:
575-
possibility = min(
576-
table_column_names,
577-
key=lambda item: bigframes._tools.strings.levenshtein_distance(
578-
key, item
579-
),
580-
)
581-
raise ValueError(
582-
f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?"
583-
)
584-
585-
# TODO(b/408499371): check `names` work with `use_cols` for read_csv method.
645+
table_column_names = [field.name for field in table.schema]
646+
rename_to_schema: Optional[Dict[str, str]] = None
586647
if names is not None:
648+
_check_names_param(names, index_col, columns, table_column_names)
649+
650+
# Any leading columns not covered by `names` are used as index columns.
587651
len_names = len(list(names))
588-
len_columns = len(table.schema)
589-
if len_names > len_columns:
590-
raise ValueError(
591-
f"Too many columns specified: expected {len_columns}"
592-
f" and found {len_names}"
593-
)
594-
elif len_names < len_columns:
595-
if (
596-
isinstance(index_col, bigframes.enums.DefaultIndexKind)
597-
or index_col != ()
598-
):
599-
raise KeyError(
600-
"When providing both `index_col` and `names`, ensure the "
601-
"number of `names` matches the number of columns in your "
602-
"data."
603-
)
604-
index_col = range(len_columns - len_names)
652+
len_schema = len(table.schema)
653+
if len(columns) == 0 and len_names < len_schema:
654+
index_col = range(len_schema - len_names)
605655
names = [
606-
field.name for field in table.schema[: len_columns - len_names]
656+
field.name for field in table.schema[: len_schema - len_names]
607657
] + list(names)
608658

659+
assert len_schema >= len_names
660+
assert len_names >= len(columns)
661+
662+
table_column_names = table_column_names[: len(list(names))]
663+
rename_to_schema = dict(zip(list(names), table_column_names))
664+
665+
if len(columns) != 0:
666+
if names is None:
667+
_check_columns_param(columns, table_column_names)
668+
else:
669+
_check_columns_param(columns, names)
670+
names = columns
671+
assert rename_to_schema is not None
672+
columns = [rename_to_schema[renamed_name] for renamed_name in columns]
673+
609674
# Converting index_col into a list of column names requires
610675
# the table metadata because we might use the primary keys
611676
# when constructing the index.
612677
index_cols = bf_read_gbq_table.get_index_cols(
613678
table=table,
614679
index_col=index_col,
615-
names=names,
680+
rename_to_schema=rename_to_schema,
616681
)
617-
columns = list(
618-
_check_column_duplicates(index_cols, columns, index_col_in_columns)
682+
_check_index_col_param(
683+
index_cols,
684+
columns,
685+
table_columns=table_column_names,
686+
index_col_in_columns=index_col_in_columns,
619687
)
620-
621-
for key in index_cols:
622-
if key not in table_column_names:
623-
possibility = min(
624-
table_column_names,
625-
key=lambda item: bigframes._tools.strings.levenshtein_distance(
626-
key, item
627-
),
628-
)
629-
raise ValueError(
630-
f"Column '{key}' of `index_col` not found in this table. Did you mean '{possibility}'?"
631-
)
688+
if index_col_in_columns and not include_all_columns:
689+
set_index = set(list(index_cols) if index_cols is not None else [])
690+
columns = [col for col in columns if col not in set_index]
632691

633692
# -----------------------------
634693
# Optionally, execute the query
@@ -715,7 +774,7 @@ def read_gbq_table(
715774
metadata_only=not self._scan_index_uniqueness,
716775
)
717776
schema = schemata.ArraySchema.from_bq_table(table)
718-
if columns:
777+
if not include_all_columns:
719778
schema = schema.select(index_cols + columns)
720779
array_value = core.ArrayValue.from_table(
721780
table,
@@ -767,14 +826,14 @@ def read_gbq_table(
767826

768827
value_columns = [col for col in array_value.column_ids if col not in index_cols]
769828
if names is not None:
770-
renamed_cols: Dict[str, str] = {
771-
col: new_name for col, new_name in zip(array_value.column_ids, names)
772-
}
829+
assert rename_to_schema is not None
830+
schema_to_rename = {value: key for key, value in rename_to_schema.items()}
773831
if index_col != bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
774832
index_names = [
775-
renamed_cols.get(index_col, index_col) for index_col in index_cols
833+
schema_to_rename.get(index_col, index_col)
834+
for index_col in index_cols
776835
]
777-
value_columns = [renamed_cols.get(col, col) for col in value_columns]
836+
value_columns = [schema_to_rename.get(col, col) for col in value_columns]
778837

779838
block = blocks.Block(
780839
array_value,
@@ -898,9 +957,7 @@ def read_gbq_query(
898957
)
899958

900959
index_cols = _to_index_cols(index_col)
901-
columns = _check_column_duplicates(
902-
index_cols, columns, index_col_in_columns=False
903-
)
960+
_check_index_col_param(index_cols, columns)
904961

905962
filters_copy1, filters_copy2 = itertools.tee(filters)
906963
has_filters = len(list(filters_copy1)) != 0

0 commit comments

Comments
 (0)