Skip to content

Commit 254875c

Browse files
authored
feat: allow fit to take additional eval data in linear and ensemble models (#1096)
* feat: allow `fit` to take additional eval data The additional eval data would be used to measure the fitted model and attach the measurement to the underlying BQML model, which can be used as benchmark for the model consumers in BigQuery Studio and otherwise. * subclass from TrainablePredictor * add support for fit-time evaluation in ensemble models * fetch logistic regression eval numbers from multiClassClassificationMetrics * use the generic type template * update vendored docstrings for fit taking X_eval, y_eval * update key to fetch model eval metrics * enfore binary classification in the logistic regression test
1 parent 5f7b8b1 commit 254875c

File tree

9 files changed

+327
-15
lines changed

9 files changed

+327
-15
lines changed

bigframes/ml/base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,40 @@ def fit(
164164
return self._fit(X, y)
165165

166166

167+
class TrainableWithEvaluationPredictor(TrainablePredictor):
168+
"""A BigQuery DataFrames ML Model base class that can be used to fit and predict outputs.
169+
170+
Additional evaluation data can be provided to measure the model in the fit phase."""
171+
172+
@abc.abstractmethod
173+
def _fit(self, X, y, transforms=None, X_eval=None, y_eval=None):
174+
pass
175+
176+
@abc.abstractmethod
177+
def score(self, X, y):
178+
pass
179+
180+
181+
class SupervisedTrainableWithEvaluationPredictor(TrainableWithEvaluationPredictor):
182+
"""A BigQuery DataFrames ML Supervised Model base class that can be used to fit and predict outputs.
183+
184+
Need to provide both X and y in supervised tasks.
185+
186+
Additional X_eval and y_eval can be provided to measure the model in the fit phase.
187+
"""
188+
189+
_T = TypeVar("_T", bound="SupervisedTrainableWithEvaluationPredictor")
190+
191+
def fit(
192+
self: _T,
193+
X: utils.ArrayType,
194+
y: utils.ArrayType,
195+
X_eval: Optional[utils.ArrayType] = None,
196+
y_eval: Optional[utils.ArrayType] = None,
197+
) -> _T:
198+
return self._fit(X, y, X_eval=X_eval, y_eval=y_eval)
199+
200+
167201
class UnsupervisedTrainablePredictor(TrainablePredictor):
168202
"""A BigQuery DataFrames ML Unsupervised Model base class that can be used to fit and predict outputs.
169203

bigframes/ml/ensemble.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
@log_adapter.class_logger
5454
class XGBRegressor(
55-
base.SupervisedTrainablePredictor,
55+
base.SupervisedTrainableWithEvaluationPredictor,
5656
bigframes_vendored.xgboost.sklearn.XGBRegressor,
5757
):
5858
__doc__ = bigframes_vendored.xgboost.sklearn.XGBRegressor.__doc__
@@ -145,14 +145,24 @@ def _fit(
145145
X: utils.ArrayType,
146146
y: utils.ArrayType,
147147
transforms: Optional[List[str]] = None,
148+
X_eval: Optional[utils.ArrayType] = None,
149+
y_eval: Optional[utils.ArrayType] = None,
148150
) -> XGBRegressor:
149151
X, y = utils.convert_to_dataframe(X, y)
150152

153+
bqml_options = self._bqml_options
154+
155+
if X_eval is not None and y_eval is not None:
156+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
157+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
158+
X, y, X_eval, y_eval, bqml_options
159+
)
160+
151161
self._bqml_model = self._bqml_model_factory.create_model(
152162
X,
153163
y,
154164
transforms=transforms,
155-
options=self._bqml_options,
165+
options=bqml_options,
156166
)
157167
return self
158168

@@ -200,7 +210,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBRegressor:
200210

201211
@log_adapter.class_logger
202212
class XGBClassifier(
203-
base.SupervisedTrainablePredictor,
213+
base.SupervisedTrainableWithEvaluationPredictor,
204214
bigframes_vendored.xgboost.sklearn.XGBClassifier,
205215
):
206216

@@ -294,14 +304,24 @@ def _fit(
294304
X: utils.ArrayType,
295305
y: utils.ArrayType,
296306
transforms: Optional[List[str]] = None,
307+
X_eval: Optional[utils.ArrayType] = None,
308+
y_eval: Optional[utils.ArrayType] = None,
297309
) -> XGBClassifier:
298310
X, y = utils.convert_to_dataframe(X, y)
299311

312+
bqml_options = self._bqml_options
313+
314+
if X_eval is not None and y_eval is not None:
315+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
316+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
317+
X, y, X_eval, y_eval, bqml_options
318+
)
319+
300320
self._bqml_model = self._bqml_model_factory.create_model(
301321
X,
302322
y,
303323
transforms=transforms,
304-
options=self._bqml_options,
324+
options=bqml_options,
305325
)
306326
return self
307327

@@ -347,7 +367,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBClassifier:
347367

348368
@log_adapter.class_logger
349369
class RandomForestRegressor(
350-
base.SupervisedTrainablePredictor,
370+
base.SupervisedTrainableWithEvaluationPredictor,
351371
bigframes_vendored.sklearn.ensemble._forest.RandomForestRegressor,
352372
):
353373

@@ -430,14 +450,24 @@ def _fit(
430450
X: utils.ArrayType,
431451
y: utils.ArrayType,
432452
transforms: Optional[List[str]] = None,
453+
X_eval: Optional[utils.ArrayType] = None,
454+
y_eval: Optional[utils.ArrayType] = None,
433455
) -> RandomForestRegressor:
434456
X, y = utils.convert_to_dataframe(X, y)
435457

458+
bqml_options = self._bqml_options
459+
460+
if X_eval is not None and y_eval is not None:
461+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
462+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
463+
X, y, X_eval, y_eval, bqml_options
464+
)
465+
436466
self._bqml_model = self._bqml_model_factory.create_model(
437467
X,
438468
y,
439469
transforms=transforms,
440-
options=self._bqml_options,
470+
options=bqml_options,
441471
)
442472
return self
443473

@@ -503,7 +533,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> RandomForestRegresso
503533

504534
@log_adapter.class_logger
505535
class RandomForestClassifier(
506-
base.SupervisedTrainablePredictor,
536+
base.SupervisedTrainableWithEvaluationPredictor,
507537
bigframes_vendored.sklearn.ensemble._forest.RandomForestClassifier,
508538
):
509539

@@ -586,14 +616,24 @@ def _fit(
586616
X: utils.ArrayType,
587617
y: utils.ArrayType,
588618
transforms: Optional[List[str]] = None,
619+
X_eval: Optional[utils.ArrayType] = None,
620+
y_eval: Optional[utils.ArrayType] = None,
589621
) -> RandomForestClassifier:
590622
X, y = utils.convert_to_dataframe(X, y)
591623

624+
bqml_options = self._bqml_options
625+
626+
if X_eval is not None and y_eval is not None:
627+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
628+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
629+
X, y, X_eval, y_eval, bqml_options
630+
)
631+
592632
self._bqml_model = self._bqml_model_factory.create_model(
593633
X,
594634
y,
595635
transforms=transforms,
596-
options=self._bqml_options,
636+
options=bqml_options,
597637
)
598638
return self
599639

bigframes/ml/linear_model.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
@log_adapter.class_logger
4949
class LinearRegression(
50-
base.SupervisedTrainablePredictor,
50+
base.SupervisedTrainableWithEvaluationPredictor,
5151
bigframes_vendored.sklearn.linear_model._base.LinearRegression,
5252
):
5353
__doc__ = bigframes_vendored.sklearn.linear_model._base.LinearRegression.__doc__
@@ -131,14 +131,24 @@ def _fit(
131131
X: utils.ArrayType,
132132
y: utils.ArrayType,
133133
transforms: Optional[List[str]] = None,
134+
X_eval: Optional[utils.ArrayType] = None,
135+
y_eval: Optional[utils.ArrayType] = None,
134136
) -> LinearRegression:
135137
X, y = utils.convert_to_dataframe(X, y)
136138

139+
bqml_options = self._bqml_options
140+
141+
if X_eval is not None and y_eval is not None:
142+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
143+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
144+
X, y, X_eval, y_eval, bqml_options
145+
)
146+
137147
self._bqml_model = self._bqml_model_factory.create_model(
138148
X,
139149
y,
140150
transforms=transforms,
141-
options=self._bqml_options,
151+
options=bqml_options,
142152
)
143153
return self
144154

@@ -183,7 +193,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> LinearRegression:
183193

184194
@log_adapter.class_logger
185195
class LogisticRegression(
186-
base.SupervisedTrainablePredictor,
196+
base.SupervisedTrainableWithEvaluationPredictor,
187197
bigframes_vendored.sklearn.linear_model._logistic.LogisticRegression,
188198
):
189199
__doc__ = (
@@ -283,15 +293,24 @@ def _fit(
283293
X: utils.ArrayType,
284294
y: utils.ArrayType,
285295
transforms: Optional[List[str]] = None,
296+
X_eval: Optional[utils.ArrayType] = None,
297+
y_eval: Optional[utils.ArrayType] = None,
286298
) -> LogisticRegression:
287-
"""Fit model with transforms."""
288299
X, y = utils.convert_to_dataframe(X, y)
289300

301+
bqml_options = self._bqml_options
302+
303+
if X_eval is not None and y_eval is not None:
304+
X_eval, y_eval = utils.convert_to_dataframe(X_eval, y_eval)
305+
X, y, bqml_options = utils.combine_training_and_evaluation_data(
306+
X, y, X_eval, y_eval, bqml_options
307+
)
308+
290309
self._bqml_model = self._bqml_model_factory.create_model(
291310
X,
292311
y,
293312
transforms=transforms,
294-
options=self._bqml_options,
313+
options=bqml_options,
295314
)
296315
return self
297316

bigframes/ml/utils.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414

1515
import typing
16-
from typing import Any, Generator, Literal, Mapping, Optional, Union
16+
from typing import Any, Generator, Literal, Mapping, Optional, Tuple, Union
1717

1818
import bigframes_vendored.constants as constants
1919
from google.cloud import bigquery
2020
import pandas as pd
2121

22-
from bigframes.core import blocks
22+
from bigframes.core import blocks, guid
2323
import bigframes.pandas as bpd
2424
from bigframes.session import Session
2525

@@ -155,3 +155,37 @@ def retrieve_params_from_bq_model(
155155
kwargs[bf_param] = bf_param_type(last_fitting[bqml_param])
156156

157157
return kwargs
158+
159+
160+
def combine_training_and_evaluation_data(
161+
X_train: bpd.DataFrame,
162+
y_train: bpd.DataFrame,
163+
X_eval: bpd.DataFrame,
164+
y_eval: bpd.DataFrame,
165+
bqml_options: dict,
166+
) -> Tuple[bpd.DataFrame, bpd.DataFrame, dict]:
167+
"""
168+
Combine training data and labels with evlauation data and labels, and keep
169+
them differentiated through a split column in the combined data and labels.
170+
"""
171+
172+
assert X_train.columns.equals(X_eval.columns)
173+
assert y_train.columns.equals(y_eval.columns)
174+
175+
# create a custom split column for BQML and supply the evaluation
176+
# data along with the training data in a combined single table
177+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-dnn-models#data_split_col.
178+
split_col = guid.generate_guid()
179+
assert split_col not in X_train.columns
180+
181+
X_train[split_col] = False
182+
X_eval[split_col] = True
183+
X = bpd.concat([X_train, X_eval])
184+
y = bpd.concat([y_train, y_eval])
185+
186+
# create options copy to not mutate the incoming one
187+
bqml_options = bqml_options.copy()
188+
bqml_options["data_split_method"] = "CUSTOM"
189+
bqml_options["data_split_col"] = split_col
190+
191+
return X, y, bqml_options

0 commit comments

Comments
 (0)