
Commit ba7c313

chelsea-lin and tswast authored
fix: address read_csv with both index_col and use_cols behavior inconsistency with pandas (#1785)
* fix: read_csv with both index_col and use_cols inconsistent with pandas
* ensure columns is not list type and avoid flaky ordering of columns
* add docstring for index_col_in_columns and fix tests

---------

Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 86159a7 commit ba7c313
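
The user-visible effect of this commit: with the BigQuery engine, `read_csv` now accepts an `index_col` that is also listed in `usecols` and promotes it to the DataFrame index, matching pandas. A minimal sketch of the call (the GCS path is a placeholder; the column names follow the tests below):

import bigframes.pandas as bpd

# Placeholder path; "rowindex" and "bool_col" mirror the test fixtures.
bf_df = bpd.read_csv(
    "gs://my-bucket/data.csv",
    engine="bigquery",
    usecols=["rowindex", "bool_col"],
    index_col="rowindex",  # previously rejected when also listed in usecols
)
print(bf_df.columns.tolist())  # ['bool_col']; "rowindex" is now the index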

File tree

3 files changed: +164 -36 lines changed


bigframes/session/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1170,7 +1170,11 @@ def _read_csv_w_bigquery_engine(

         table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
         df = self._loader.read_gbq_table(
-            table_id, index_col=index_col, columns=columns, names=names
+            table_id,
+            index_col=index_col,
+            columns=columns,
+            names=names,
+            index_col_in_columns=True,
         )

         if dtype is not None:
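
For reference, this is the pandas behavior the BigQuery engine path is being aligned with: when `index_col` is also listed in `usecols`, pandas promotes it to the index and drops it from the value columns. A small self-contained check (the inline CSV is invented for illustration):

import io

import pandas as pd

# Inline CSV invented for illustration only.
csv = io.StringIO("rowindex,bool_col,int_col\n0,True,1\n1,False,2\n")
pd_df = pd.read_csv(csv, usecols=["rowindex", "bool_col"], index_col="rowindex")

print(pd_df.index.name)        # 'rowindex'
print(pd_df.columns.tolist())  # ['bool_col']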

bigframes/session/loader.py

Lines changed: 110 additions & 8 deletions
@@ -96,7 +96,31 @@ def _to_index_cols(
     return index_cols


-def _check_column_duplicates(index_cols: Iterable[str], columns: Iterable[str]):
+def _check_column_duplicates(
+    index_cols: Iterable[str], columns: Iterable[str], index_col_in_columns: bool
+) -> Iterable[str]:
+    """Validates and processes index and data columns for duplicates and overlap.
+
+    This function performs two main tasks:
+    1. Ensures there are no duplicate column names within the `index_cols` list
+       or within the `columns` list.
+    2. Based on the `index_col_in_columns` flag, it validates the relationship
+       between `index_cols` and `columns`.
+
+    Args:
+        index_cols (Iterable[str]):
+            An iterable of column names designated as the index.
+        columns (Iterable[str]):
+            An iterable of column names designated as the data columns.
+        index_col_in_columns (bool):
+            A flag indicating how to handle overlap between `index_cols` and
+            `columns`.
+            - If `False`, the two lists must be disjoint (contain no common
+              elements). An error is raised if any overlap is found.
+            - If `True`, `index_cols` is expected to be a subset of
+              `columns`. An error is raised if an index column is not found
+              in the `columns` list.
+    """
     index_cols_list = list(index_cols) if index_cols is not None else []
     columns_list = list(columns) if columns is not None else []
     set_index = set(index_cols_list)
@@ -108,17 +132,29 @@ def _check_column_duplicates(index_cols: Iterable[str], columns: Iterable[str]):
             "All column names specified in 'index_col' must be unique."
         )

+    if len(columns_list) == 0:
+        return columns
+
     if len(columns_list) > len(set_columns):
         raise ValueError(
             "The 'columns' argument contains duplicate names. "
             "All column names specified in 'columns' must be unique."
         )

-    if not set_index.isdisjoint(set_columns):
-        raise ValueError(
-            "Found column names that exist in both 'index_col' and 'columns' arguments. "
-            "These arguments must specify distinct sets of columns."
-        )
+    if index_col_in_columns:
+        if not set_index.issubset(set_columns):
+            raise ValueError(
+                f"The specified index column(s) were not found: {set_index - set_columns}. "
+                f"Available columns are: {set_columns}"
+            )
+        return [col for col in columns if col not in set_index]
+    else:
+        if not set_index.isdisjoint(set_columns):
+            raise ValueError(
+                "Found column names that exist in both 'index_col' and 'columns' arguments. "
+                "These arguments must specify distinct sets of columns."
+            )
+        return columns


 @dataclasses.dataclass
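
To make the helper's contract concrete, here is a small sketch of the two modes (`_check_column_duplicates` is a private helper in `bigframes.session.loader`; calling it directly is shown only for illustration, with placeholder column names):

from bigframes.session.loader import _check_column_duplicates

# index_col_in_columns=True: index columns must be a subset of `columns`
# and are removed from the returned value columns.
value_cols = _check_column_duplicates(
    index_cols=["rowindex"],
    columns=["rowindex", "bool_col"],
    index_col_in_columns=True,
)
print(list(value_cols))  # ['bool_col']

# index_col_in_columns=False: the two arguments must stay disjoint.
try:
    _check_column_duplicates(
        index_cols=["rowindex"],
        columns=["rowindex", "bool_col"],
        index_col_in_columns=False,
    )
except ValueError as exc:
    print(exc)  # overlap between 'index_col' and 'columns' is rejected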
@@ -391,6 +427,7 @@ def read_gbq_table( # type: ignore[overload-overlap]
         dry_run: Literal[False] = ...,
         force_total_order: Optional[bool] = ...,
         n_rows: Optional[int] = None,
+        index_col_in_columns: bool = False,
     ) -> dataframe.DataFrame:
         ...

@@ -413,6 +450,7 @@ def read_gbq_table(
         dry_run: Literal[True] = ...,
         force_total_order: Optional[bool] = ...,
         n_rows: Optional[int] = None,
+        index_col_in_columns: bool = False,
     ) -> pandas.Series:
         ...

@@ -434,7 +472,67 @@ def read_gbq_table(
         dry_run: bool = False,
         force_total_order: Optional[bool] = None,
         n_rows: Optional[int] = None,
+        index_col_in_columns: bool = False,
     ) -> dataframe.DataFrame | pandas.Series:
+        """Read a BigQuery table into a BigQuery DataFrames DataFrame.
+
+        This method allows you to create a DataFrame from a BigQuery table.
+        You can specify the columns to load, an index column, and apply
+        filters.
+
+        Args:
+            table_id (str):
+                The identifier of the BigQuery table to read.
+            index_col (Iterable[str] | str | Iterable[int] | int | bigframes.enums.DefaultIndexKind, optional):
+                The column(s) to use as the index for the DataFrame. This can be
+                a single column name or a list of column names. If not provided,
+                a default index will be used based on the session's
+                ``default_index_type``.
+            columns (Iterable[str], optional):
+                The columns to read from the table. If not specified, all
+                columns will be read.
+            names (Optional[Iterable[str]], optional):
+                A list of column names to use for the resulting DataFrame. This
+                is useful if you want to rename the columns as you read the
+                data.
+            max_results (Optional[int], optional):
+                The maximum number of rows to retrieve from the table. If not
+                specified, all rows will be loaded.
+            use_cache (bool, optional):
+                Whether to use cached results for the query. Defaults to True.
+                Setting this to False will force a re-execution of the query.
+            filters (third_party_pandas_gbq.FiltersType, optional):
+                A list of filters to apply to the data. Filters are specified
+                as a list of tuples, where each tuple contains a column name,
+                an operator (e.g., '==', '!='), and a value.
+            enable_snapshot (bool, optional):
+                If True, a snapshot of the table is used to ensure that the
+                DataFrame is deterministic, even if the underlying table
+                changes. Defaults to True.
+            dry_run (bool, optional):
+                If True, the function will not actually execute the query but
+                will instead return statistics about the table. Defaults to False.
+            force_total_order (Optional[bool], optional):
+                If True, a total ordering is enforced on the DataFrame, which
+                can be useful for operations that require a stable row order.
+                If None, the session's default behavior is used.
+            n_rows (Optional[int], optional):
+                The number of rows to consider for type inference and other
+                metadata operations. This does not limit the number of rows
+                in the final DataFrame.
+            index_col_in_columns (bool, optional):
+                Specifies if the ``index_col`` is also present in the ``columns``
+                list. Defaults to ``False``.
+
+                * If ``False``, ``index_col`` and ``columns`` must specify
+                  distinct sets of columns. An error will be raised if any
+                  column is found in both.
+                * If ``True``, the column(s) in ``index_col`` are expected to
+                  also be present in the ``columns`` list. This is useful
+                  when the index is selected from the data columns (e.g., in a
+                  ``read_csv`` scenario). The column will be used as the
+                  DataFrame's index and removed from the list of value columns.
+        """
         import bigframes._tools.strings
         import bigframes.dataframe as dataframe

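A sketch of how the new parameter is meant to be used on the loader (the loader is the session's private `_loader` attribute, as in the `read_csv` call site above; the table id and column names are placeholders):

import bigframes.pandas as bpd

session = bpd.get_global_session()

# Placeholders: table id and column names are illustrative only.
df = session._loader.read_gbq_table(
    "my-project.my_dataset.my_table",
    index_col="rowindex",
    columns=["rowindex", "bool_col"],
    index_col_in_columns=True,  # "rowindex" is taken from `columns` and promoted to the index
)
print(df.columns.tolist())  # ['bool_col']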
@@ -516,7 +614,9 @@ def read_gbq_table(
             index_col=index_col,
             names=names,
         )
-        _check_column_duplicates(index_cols, columns)
+        columns = list(
+            _check_column_duplicates(index_cols, columns, index_col_in_columns)
+        )

         for key in index_cols:
             if key not in table_column_names:
@@ -798,7 +898,9 @@ def read_gbq_query(
         )

         index_cols = _to_index_cols(index_col)
-        _check_column_duplicates(index_cols, columns)
+        columns = _check_column_duplicates(
+            index_cols, columns, index_col_in_columns=False
+        )

         filters_copy1, filters_copy2 = itertools.tee(filters)
         has_filters = len(list(filters_copy1)) != 0

tests/system/small/test_session.py

Lines changed: 49 additions & 27 deletions
@@ -1320,10 +1320,6 @@ def test_read_csv_for_names_less_than_columns(session, df_and_gcs_csv_for_two_co
     assert bf_df.shape == pd_df.shape
     assert bf_df.columns.tolist() == pd_df.columns.tolist()

-    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
-    # (b/280889935) or guarantee row ordering.
-    bf_df = bf_df.sort_index()
-
     # Pandas's index name is None, while BigFrames's index name is "rowindex".
     pd_df.index.name = "rowindex"
     pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
@@ -1479,41 +1475,70 @@ def test_read_csv_for_gcs_file_w_header(session, df_and_gcs_csv, header):
 def test_read_csv_w_usecols(session, df_and_local_csv):
     # Compares results for pandas and bigframes engines
     scalars_df, path = df_and_local_csv
+    usecols = ["rowindex", "bool_col"]
     with open(path, "rb") as buffer:
         bf_df = session.read_csv(
             buffer,
             engine="bigquery",
-            usecols=["bool_col"],
+            usecols=usecols,
         )
     with open(path, "rb") as buffer:
         # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
         pd_df = session.read_csv(
             buffer,
-            usecols=["bool_col"],
+            usecols=usecols,
             dtype=scalars_df[["bool_col"]].dtypes.to_dict(),
         )

-    # Cannot compare two dataframe due to b/408499371.
-    assert len(bf_df.columns) == 1
-    assert len(pd_df.columns) == 1
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()

+    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
+    # (b/280889935) or guarantee row ordering.
+    bf_df = bf_df.set_index("rowindex").sort_index()
+    pd_df = pd_df.set_index("rowindex")
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())

-@pytest.mark.parametrize(
-    "engine",
-    [
-        pytest.param("bigquery", id="bq_engine"),
-        pytest.param(None, id="default_engine"),
-    ],
-)
-def test_read_csv_local_w_usecols(session, scalars_pandas_df_index, engine):
-    with tempfile.TemporaryDirectory() as dir:
-        path = dir + "/test_read_csv_local_w_usecols.csv"
-        # Using the pandas to_csv method because the BQ one does not support local write.
-        scalars_pandas_df_index.to_csv(path, index=False)

-        # df should only have 1 column which is bool_col.
-        df = session.read_csv(path, usecols=["bool_col"], engine=engine)
-        assert len(df.columns) == 1
+def test_read_csv_w_usecols_and_indexcol(session, df_and_local_csv):
+    # Compares results for pandas and bigframes engines
+    scalars_df, path = df_and_local_csv
+    usecols = ["rowindex", "bool_col"]
+    with open(path, "rb") as buffer:
+        bf_df = session.read_csv(
+            buffer,
+            engine="bigquery",
+            usecols=usecols,
+            index_col="rowindex",
+        )
+    with open(path, "rb") as buffer:
+        # Convert default pandas dtypes to match BigQuery DataFrames dtypes.
+        pd_df = session.read_csv(
+            buffer,
+            usecols=usecols,
+            index_col="rowindex",
+            dtype=scalars_df[["bool_col"]].dtypes.to_dict(),
+        )
+
+    assert bf_df.shape == pd_df.shape
+    assert bf_df.columns.tolist() == pd_df.columns.tolist()
+
+    pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas())
+
+
+def test_read_csv_w_indexcol_not_in_usecols(session, df_and_local_csv):
+    _, path = df_and_local_csv
+    with open(path, "rb") as buffer:
+        with pytest.raises(
+            ValueError,
+            match=re.escape("The specified index column(s) were not found"),
+        ):
+            session.read_csv(
+                buffer,
+                engine="bigquery",
+                usecols=["bool_col"],
+                index_col="rowindex",
+            )


 @pytest.mark.parametrize(
@@ -1553,9 +1578,6 @@ def test_read_csv_local_w_encoding(session, penguins_pandas_df_default_index):
     bf_df = session.read_csv(
         path, engine="bigquery", index_col="rowindex", encoding="ISO-8859-1"
     )
-    # BigFrames requires `sort_index()` because BigQuery doesn't preserve row IDs
-    # (b/280889935) or guarantee row ordering.
-    bf_df = bf_df.sort_index()
     pd.testing.assert_frame_equal(
         bf_df.to_pandas(), penguins_pandas_df_default_index
     )
