Skip to content

Commit afa7cc4

Browse files
authored
refactor: use core.convert for series conversions under the ml packages (#1178)
* refactor: use core.convert for series conversions under the ml packages * update method name * fetch global session lazily
1 parent 557ab8d commit afa7cc4

File tree

6 files changed

+48
-49
lines changed

6 files changed

+48
-49
lines changed

bigframes/core/convert.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pandas as pd
1919

20+
from bigframes.core import global_session
2021
import bigframes.core.indexes as index
2122
import bigframes.series as series
2223

@@ -36,7 +37,9 @@ def is_series_convertible(obj) -> bool:
3637
return False
3738

3839

39-
def to_bf_series(obj, default_index: Optional[index.Index], session) -> series.Series:
40+
def to_bf_series(
41+
obj, default_index: Optional[index.Index], session=None
42+
) -> series.Series:
4043
"""
4144
Convert a an object to a bigframes series
4245
@@ -51,6 +54,10 @@ def to_bf_series(obj, default_index: Optional[index.Index], session) -> series.S
5154
"""
5255
if isinstance(obj, series.Series):
5356
return obj
57+
58+
if session is None:
59+
session = global_session.get_global_session()
60+
5461
if isinstance(obj, pd.Series):
5562
return series.Series(obj, session=session)
5663
if isinstance(obj, index.Index):

bigframes/ml/metrics/_metrics.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def r2_score(
3737
*,
3838
force_finite=True,
3939
) -> float:
40-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
40+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
4141

4242
# total sum of squares
4343
# (dataframe, scalar) binops
@@ -66,7 +66,7 @@ def accuracy_score(
6666
normalize=True,
6767
) -> float:
6868
# TODO(ashleyxu): support sample_weight as the parameter
69-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
69+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
7070

7171
# Compute accuracy for each possible representation
7272
# TODO(ashleyxu): add multilabel classification support where y_type
@@ -97,7 +97,7 @@ def roc_curve(
9797
f"drop_intermediate is not yet implemented. {constants.FEEDBACK_LINK}"
9898
)
9999

100-
y_true_series, y_score_series = utils.convert_to_series(y_true, y_score)
100+
y_true_series, y_score_series = utils.batch_convert_to_series(y_true, y_score)
101101

102102
session = y_true_series._block.expr.session
103103

@@ -157,7 +157,7 @@ def roc_auc_score(
157157
) -> float:
158158
# TODO(bmil): Add multi-class support
159159
# TODO(bmil): Add multi-label support
160-
y_true_series, y_score_series = utils.convert_to_series(y_true, y_score)
160+
y_true_series, y_score_series = utils.batch_convert_to_series(y_true, y_score)
161161

162162
fpr, tpr, _ = roc_curve(y_true_series, y_score_series, drop_intermediate=False)
163163

@@ -174,7 +174,7 @@ def auc(
174174
x: Union[bpd.DataFrame, bpd.Series],
175175
y: Union[bpd.DataFrame, bpd.Series],
176176
) -> float:
177-
x_series, y_series = utils.convert_to_series(x, y)
177+
x_series, y_series = utils.batch_convert_to_series(x, y)
178178

179179
# TODO(b/286410053) Support ML exceptions and error handling.
180180
auc = sklearn_metrics.auc(x_series.to_pandas(), y_series.to_pandas())
@@ -189,7 +189,7 @@ def confusion_matrix(
189189
y_pred: Union[bpd.DataFrame, bpd.Series],
190190
) -> pd.DataFrame:
191191
# TODO(ashleyxu): support labels and sample_weight parameters
192-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
192+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
193193

194194
y_true_series = y_true_series.rename("y_true")
195195
confusion_df = y_true_series.to_frame().assign(y_pred=y_pred_series)
@@ -235,7 +235,7 @@ def recall_score(
235235
f"Only average=None is supported. {constants.FEEDBACK_LINK}"
236236
)
237237

238-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
238+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
239239

240240
is_accurate = y_true_series == y_pred_series
241241
unique_labels = (
@@ -272,7 +272,7 @@ def precision_score(
272272
f"Only average=None is supported. {constants.FEEDBACK_LINK}"
273273
)
274274

275-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
275+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
276276

277277
is_accurate = y_true_series == y_pred_series
278278
unique_labels = (
@@ -306,7 +306,7 @@ def f1_score(
306306
average: typing.Optional[str] = "binary",
307307
) -> pd.Series:
308308
# TODO(ashleyxu): support more average type, default to "binary"
309-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
309+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
310310

311311
if average is not None:
312312
raise NotImplementedError(
@@ -337,7 +337,7 @@ def mean_squared_error(
337337
y_true: Union[bpd.DataFrame, bpd.Series],
338338
y_pred: Union[bpd.DataFrame, bpd.Series],
339339
) -> float:
340-
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
340+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
341341

342342
return (y_pred_series - y_true_series).pow(2).sum() / len(y_true_series)
343343

bigframes/ml/metrics/pairwise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
def paired_cosine_distances(
2626
X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
2727
) -> bpd.DataFrame:
28-
X, Y = utils.convert_to_series(X, Y)
28+
X, Y = utils.batch_convert_to_series(X, Y)
2929
joined_block, _ = X._block.join(Y._block, how="outer")
3030

3131
result_block, _ = joined_block.project_expr(
@@ -45,7 +45,7 @@ def paired_cosine_distances(
4545
def paired_manhattan_distance(
4646
X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
4747
) -> bpd.DataFrame:
48-
X, Y = utils.convert_to_series(X, Y)
48+
X, Y = utils.batch_convert_to_series(X, Y)
4949
joined_block, _ = X._block.join(Y._block, how="outer")
5050

5151
result_block, _ = joined_block.project_expr(
@@ -65,7 +65,7 @@ def paired_manhattan_distance(
6565
def paired_euclidean_distances(
6666
X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
6767
) -> bpd.DataFrame:
68-
X, Y = utils.convert_to_series(X, Y)
68+
X, Y = utils.batch_convert_to_series(X, Y)
6969
joined_block, _ = X._block.join(Y._block, how="outer")
7070

7171
result_block, _ = joined_block.project_expr(

bigframes/ml/model_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def _convert_to_bf_type(
162162
if isinstance(type_instance, pd.Series) or isinstance(
163163
type_instance, bpd.Series
164164
):
165-
return next(utils.convert_to_series(input))
165+
return next(utils.batch_convert_to_series(input))
166166

167167
if isinstance(type_instance, pd.DataFrame) or isinstance(
168168
type_instance, bpd.DataFrame

bigframes/ml/utils.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414

1515
import typing
16-
from typing import Any, Generator, Literal, Mapping, Optional, Tuple, Union
16+
from typing import Any, Generator, Hashable, Literal, Mapping, Optional, Tuple, Union
1717

1818
import bigframes_vendored.constants as constants
1919
from google.cloud import bigquery
2020
import pandas as pd
2121

22-
from bigframes.core import blocks, guid
22+
from bigframes.core import convert, guid
2323
import bigframes.pandas as bpd
2424
from bigframes.session import Session
2525

@@ -65,7 +65,7 @@ def _convert_to_dataframe(
6565
)
6666

6767

68-
def convert_to_series(
68+
def batch_convert_to_series(
6969
*input: ArrayType, session: Optional[Session] = None
7070
) -> Generator[bpd.Series, None, None]:
7171
"""Converts the input to BigFrames Series.
@@ -76,37 +76,29 @@ def convert_to_series(
7676
It is not used if the input itself is already a BigFrame data frame or series.
7777
7878
"""
79-
return (_convert_to_series(frame, session) for frame in input)
79+
return (
80+
convert.to_bf_series(
81+
_get_only_column(frame), default_index=None, session=session
82+
)
83+
for frame in input
84+
)
8085

8186

82-
def _convert_to_series(
83-
frame: ArrayType, session: Optional[Session] = None
84-
) -> bpd.Series:
85-
if isinstance(frame, bpd.DataFrame):
86-
if len(frame.columns) != 1:
87-
raise ValueError(
88-
"To convert into Series, DataFrames can only contain one column. "
89-
f"Try input with only one column. {constants.FEEDBACK_LINK}"
90-
)
91-
92-
label = typing.cast(blocks.Label, frame.columns.tolist()[0])
93-
return typing.cast(bpd.Series, frame[label])
94-
if isinstance(frame, bpd.Series):
95-
return frame
96-
if isinstance(frame, pd.DataFrame):
97-
# Recursively call this method to re-use the length-checking logic
98-
if session is None:
99-
return _convert_to_series(bpd.read_pandas(frame))
100-
else:
101-
return _convert_to_series(session.read_pandas(frame), session)
102-
if isinstance(frame, pd.Series):
103-
if session is None:
104-
return bpd.read_pandas(frame)
105-
else:
106-
return session.read_pandas(frame)
107-
raise ValueError(
108-
f"Unsupported type {type(frame)} to convert to Series. {constants.FEEDBACK_LINK}"
109-
)
87+
def _get_only_column(input: ArrayType) -> Union[pd.Series, bpd.Series]:
88+
if isinstance(input, pd.Series) or isinstance(input, bpd.Series):
89+
return input
90+
91+
if len(input.columns) != 1:
92+
raise ValueError(
93+
"To convert into Series, DataFrames can only contain one column. "
94+
f"Try input with only one column. {constants.FEEDBACK_LINK}"
95+
)
96+
97+
label = typing.cast(Hashable, input.columns.tolist()[0])
98+
if isinstance(input, pd.DataFrame):
99+
return typing.cast(pd.Series, input[label])
100+
101+
return typing.cast(bpd.Series, input[label])
110102

111103

112104
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:

tests/system/small/ml/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_convert_pandas_to_dataframe(data, session):
6161
def test_convert_to_series(session, data):
6262
bf_data = session.read_pandas(data)
6363

64-
(actual_result,) = utils.convert_to_series(bf_data)
64+
(actual_result,) = utils.batch_convert_to_series(bf_data)
6565

6666
pandas.testing.assert_series_equal(
6767
actual_result.to_pandas(), _SERIES, check_index_type=False, check_dtype=False
@@ -73,7 +73,7 @@ def test_convert_to_series(session, data):
7373
[pytest.param(_DATA_FRAME, id="dataframe"), pytest.param(_SERIES, id="series")],
7474
)
7575
def test_convert_pandas_to_series(data, session):
76-
(actual_result,) = utils.convert_to_series(data, session=session)
76+
(actual_result,) = utils.batch_convert_to_series(data, session=session)
7777

7878
pandas.testing.assert_series_equal(
7979
actual_result.to_pandas(), _SERIES, check_index_type=False, check_dtype=False

0 commit comments

Comments
 (0)