Skip to content

Commit 2c72c56

Browse files
authored
feat: add parameter shuffle for ml.model_selection.train_test_split (#2030)
* feat: add parameter shuffle for ml.model_selection.train_test_split * mypy * rename
1 parent fc44bc8 commit 2c72c56

File tree

3 files changed

+120
-4
lines changed

3 files changed

+120
-4
lines changed

bigframes/ml/model_selection.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import inspect
21+
from itertools import chain
2122
import time
2223
from typing import cast, Generator, List, Optional, Union
2324

@@ -36,12 +37,9 @@ def train_test_split(
3637
train_size: Union[float, None] = None,
3738
random_state: Union[int, None] = None,
3839
stratify: Union[bpd.Series, None] = None,
40+
shuffle: bool = True,
3941
) -> List[Union[bpd.DataFrame, bpd.Series]]:
4042

41-
# TODO(garrettwu): scikit-learn throws an error when the dataframes don't have the same
42-
# number of rows. We probably want to do something similar. Now the implementation is based
43-
# on index. We'll move to based on ordering first.
44-
4543
if test_size is None:
4644
if train_size is None:
4745
test_size = 0.25
@@ -61,6 +59,26 @@ def train_test_split(
6159
f"The sum of train_size and test_size exceeds 1.0. train_size: {train_size}. test_size: {test_size}"
6260
)
6361

62+
if not shuffle:
63+
if stratify is not None:
64+
raise ValueError(
65+
"Stratified train/test split is not implemented for shuffle=False"
66+
)
67+
bf_arrays = list(utils.batch_convert_to_bf_equivalent(*arrays))
68+
69+
total_rows = len(bf_arrays[0])
70+
train_rows = int(total_rows * train_size)
71+
test_rows = total_rows - train_rows
72+
73+
return list(
74+
chain.from_iterable(
75+
[
76+
[bf_array.head(train_rows), bf_array.tail(test_rows)]
77+
for bf_array in bf_arrays
78+
]
79+
)
80+
)
81+
6482
dfs = list(utils.batch_convert_to_dataframe(*arrays))
6583

6684
def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]:

bigframes/ml/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,30 @@ def batch_convert_to_series(
7979
)
8080

8181

82+
def batch_convert_to_bf_equivalent(
    *input: ArrayType, session: Optional[Session] = None
) -> Generator[Union[bpd.DataFrame, bpd.Series], None, None]:
    """Yield the BigFrames DataFrame or Series equivalent of every input.

    DataFrame inputs (BigFrames or pandas) stay DataFrames and Series inputs
    (BigFrames or pandas) stay Series. Because this is a generator, the
    session validation and the conversions only run once it is iterated.

    Args:
        *input:
            The frames and series to convert; results are yielded in the same order.
        session:
            The session to convert local pandas instances to BigFrames counterparts.
            It is not used if the input itself is already a BigFrames DataFrame or Series.

    Yields:
        One converted BigFrames DataFrame or Series per input.

    Raises:
        ValueError: If an input is neither a DataFrame nor a Series.
    """
    _validate_sessions(*input, session=session)

    for item in input:
        if isinstance(item, (bpd.DataFrame, pd.DataFrame)):
            yield convert.to_bf_dataframe(item, default_index=None, session=session)
            continue

        if not isinstance(item, (bpd.Series, pd.Series)):
            raise ValueError(f"Unsupported type: {type(item)}")

        only_column = _get_only_column(item)
        yield convert.to_bf_series(only_column, default_index=None, session=session)
104+
105+
82106
def _validate_sessions(*input: ArrayType, session: Optional[Session]):
83107
session_ids = set(
84108
i._session.session_id

tests/system/small/ml/test_model_selection.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
import math
16+
from typing import cast
1617

1718
import pandas as pd
1819
import pytest
1920

2021
from bigframes.ml import model_selection
2122
import bigframes.pandas as bpd
23+
import bigframes.session
2224

2325

2426
@pytest.mark.parametrize(
@@ -219,6 +221,78 @@ def test_train_test_split_seeded_correct_rows(
219221
)
220222

221223

224+
def test_train_test_split_no_shuffle_correct_shape(
    penguins_df_default_index: bpd.DataFrame,
):
    """With ``shuffle=False`` the split keeps input types and the expected shapes."""
    features = penguins_df_default_index[["species"]]
    target = penguins_df_default_index["body_mass_g"]

    splits = model_selection.train_test_split(features, target, shuffle=False)
    X_train, X_test, y_train, y_test = splits

    # A DataFrame input must come back as DataFrames, a Series input as Series.
    assert isinstance(X_train, bpd.DataFrame)
    assert isinstance(X_test, bpd.DataFrame)
    assert isinstance(y_train, bpd.Series)
    assert isinstance(y_test, bpd.Series)

    # Expected row counts for the penguins fixture: 258 train rows, 86 test rows.
    expected_train_rows, expected_test_rows = 258, 86
    assert X_train.shape == (expected_train_rows, 1)
    assert X_test.shape == (expected_test_rows, 1)
    assert y_train.shape == (expected_train_rows,)
    assert y_test.shape == (expected_test_rows,)
241+
242+
243+
def test_train_test_split_no_shuffle_correct_rows(
    session: bigframes.session.Session, penguins_pandas_df_default_index: pd.DataFrame
):
    """With ``shuffle=False`` the leading rows are the train set and the trailing rows the test set.

    The expected frames are built locally with pandas ``head``/``tail`` and
    compared against what ``train_test_split`` returns for the same data
    loaded into BigFrames.
    """
    # Note that we're using `penguins_pandas_df_default_index` as this test depends
    # on a stable row order being present end to end
    # filter down to the chunkiest penguins, to keep our test code a reasonable size
    all_data = penguins_pandas_df_default_index[
        penguins_pandas_df_default_index.body_mass_g > 5500
    ].sort_index()

    # Note that bigframes loses the index if it doesn't have a name
    all_data.index.name = "rowindex"

    df = session.read_pandas(all_data)

    # Defined once so the inputs and the expected frames cannot drift apart.
    feature_columns = ["species", "island", "culmen_length_mm"]
    target_column = "body_mass_g"

    X = df[feature_columns]
    y = df[target_column]
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, shuffle=False
    )

    X_train_pd = cast(bpd.DataFrame, X_train).to_pandas()
    X_test_pd = cast(bpd.DataFrame, X_test).to_pandas()
    y_train_pd = cast(bpd.Series, y_train).to_pandas()
    y_test_pd = cast(bpd.Series, y_test).to_pandas()

    # No sizes are passed above, so the split is expected to be 75% train / 25% test.
    total_rows = len(all_data)
    train_size = 0.75
    train_rows = int(total_rows * train_size)
    test_rows = total_rows - train_rows

    expected_train = all_data.head(train_rows)
    expected_test = all_data.tail(test_rows)

    expected_X_train = expected_train[feature_columns]
    expected_y_train = expected_train[target_column]

    expected_X_test = expected_test[feature_columns]
    expected_y_test = expected_test[target_column]

    pd.testing.assert_frame_equal(X_train_pd, expected_X_train)
    pd.testing.assert_frame_equal(X_test_pd, expected_X_test)
    pd.testing.assert_series_equal(y_train_pd, expected_y_train)
    pd.testing.assert_series_equal(y_test_pd, expected_y_test)
294+
295+
222296
@pytest.mark.parametrize(
223297
("train_size", "test_size"),
224298
[

0 commit comments

Comments
 (0)