Skip to content

Commit 5a7b1c9

Browse files
authored
refactor: ml model load read from class type hints (#656)
* refactor: ml model load read from class type hints * exclude unrelated files * fix NoneType * fix tests * fix tests * fix param mappings * fix tests
1 parent ff23b18 commit 5a7b1c9

File tree

10 files changed

+192
-231
lines changed

10 files changed

+192
-231
lines changed

bigframes/ml/cluster.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"distance_type": "distanceType",
3535
"max_iter": "maxIterations",
3636
"tol": "minRelativeProgress",
37+
"warm_start": "warmStart",
3738
}
3839

3940

@@ -67,27 +68,18 @@ def __init__(
6768
self._bqml_model_factory = globals.bqml_model_factory()
6869

6970
@classmethod
70-
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> KMeans:
71-
assert model.model_type == "KMEANS"
71+
def _from_bq(cls, session: bigframes.Session, bq_model: bigquery.Model) -> KMeans:
72+
assert bq_model.model_type == "KMEANS"
7273

7374
kwargs: dict = {}
7475

75-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
76-
last_fitting = model.training_runs[-1]["trainingOptions"]
77-
dummy_kmeans = cls()
78-
for bf_param, bf_value in dummy_kmeans.__dict__.items():
79-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
80-
if bqml_param in last_fitting:
81-
# Convert types
82-
kwargs[bf_param] = (
83-
str(last_fitting[bqml_param])
84-
if bf_param in ["init"]
85-
else type(bf_value)(last_fitting[bqml_param])
86-
)
87-
88-
new_kmeans = cls(**kwargs)
89-
new_kmeans._bqml_model = core.BqmlModel(session, model)
90-
return new_kmeans
76+
kwargs = utils.retrieve_params_from_bq_model(
77+
cls, bq_model, _BQML_PARAMS_MAPPING
78+
)
79+
80+
model = cls(**kwargs)
81+
model._bqml_model = core.BqmlModel(session, bq_model)
82+
return model
9183

9284
@property
9385
def _bqml_options(self) -> dict:

bigframes/ml/decomposition.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from bigframes.ml import base, core, globals, utils
2828
import bigframes.pandas as bpd
2929

30+
_BQML_PARAMS_MAPPING = {"svd_solver": "pcaSolver"}
31+
3032

3133
@log_adapter.class_logger
3234
class PCA(
@@ -47,23 +49,22 @@ def __init__(
4749
self._bqml_model_factory = globals.bqml_model_factory()
4850

4951
@classmethod
50-
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> PCA:
51-
assert model.model_type == "PCA"
52+
def _from_bq(cls, session: bigframes.Session, bq_model: bigquery.Model) -> PCA:
53+
assert bq_model.model_type == "PCA"
5254

53-
kwargs: dict = {}
55+
kwargs = utils.retrieve_params_from_bq_model(
56+
cls, bq_model, _BQML_PARAMS_MAPPING
57+
)
5458

55-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
56-
last_fitting = model.training_runs[-1]["trainingOptions"]
59+
last_fitting = bq_model.training_runs[-1]["trainingOptions"]
5760
if "numPrincipalComponents" in last_fitting:
5861
kwargs["n_components"] = int(last_fitting["numPrincipalComponents"])
59-
if "pcaExplainedVarianceRatio" in last_fitting:
62+
elif "pcaExplainedVarianceRatio" in last_fitting:
6063
kwargs["n_components"] = float(last_fitting["pcaExplainedVarianceRatio"])
61-
if "pcaSolver" in last_fitting:
62-
kwargs["svd_solver"] = str(last_fitting["pcaSolver"])
6364

64-
new_pca = cls(**kwargs)
65-
new_pca._bqml_model = core.BqmlModel(session, model)
66-
return new_pca
65+
model = cls(**kwargs)
66+
model._bqml_model = core.BqmlModel(session, bq_model)
67+
return model
6768

6869
@property
6970
def _bqml_options(self) -> dict:

bigframes/ml/ensemble.py

Lines changed: 47 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030

3131
_BQML_PARAMS_MAPPING = {
3232
"booster": "boosterType",
33+
"dart_normalized_type": "dartNormalizeType",
3334
"tree_method": "treeMethod",
34-
"colsample_bytree": "colsampleBylevel",
35-
"colsample_bylevel": "colsampleBytree",
35+
"colsample_bytree": "colsampleBytree",
36+
"colsample_bylevel": "colsampleBylevel",
3637
"colsample_bynode": "colsampleBynode",
3738
"gamma": "minSplitLoss",
3839
"subsample": "subsample",
@@ -44,6 +45,8 @@
4445
"min_tree_child_weight": "minTreeChildWeight",
4546
"max_depth": "maxTreeDepth",
4647
"max_iterations": "maxIterations",
48+
"enable_global_explain": "enableGlobalExplain",
49+
"xgboost_version": "xgboostVersion",
4750
}
4851

4952

@@ -99,24 +102,17 @@ def __init__(
99102

100103
@classmethod
101104
def _from_bq(
102-
cls, session: bigframes.Session, model: bigquery.Model
105+
cls, session: bigframes.Session, bq_model: bigquery.Model
103106
) -> XGBRegressor:
104-
assert model.model_type == "BOOSTED_TREE_REGRESSOR"
107+
assert bq_model.model_type == "BOOSTED_TREE_REGRESSOR"
105108

106-
kwargs = {}
107-
108-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
109-
last_fitting = model.training_runs[-1]["trainingOptions"]
110-
111-
dummy_regressor = cls()
112-
for bf_param, bf_value in dummy_regressor.__dict__.items():
113-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
114-
if bqml_param in last_fitting:
115-
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
109+
kwargs = utils.retrieve_params_from_bq_model(
110+
cls, bq_model, _BQML_PARAMS_MAPPING
111+
)
116112

117-
new_xgb_regressor = cls(**kwargs)
118-
new_xgb_regressor._bqml_model = core.BqmlModel(session, model)
119-
return new_xgb_regressor
113+
model = cls(**kwargs)
114+
model._bqml_model = core.BqmlModel(session, bq_model)
115+
return model
120116

121117
@property
122118
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
@@ -255,24 +251,17 @@ def __init__(
255251

256252
@classmethod
257253
def _from_bq(
258-
cls, session: bigframes.Session, model: bigquery.Model
254+
cls, session: bigframes.Session, bq_model: bigquery.Model
259255
) -> XGBClassifier:
260-
assert model.model_type == "BOOSTED_TREE_CLASSIFIER"
256+
assert bq_model.model_type == "BOOSTED_TREE_CLASSIFIER"
261257

262-
kwargs = {}
263-
264-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
265-
last_fitting = model.training_runs[-1]["trainingOptions"]
266-
267-
dummy_classifier = XGBClassifier()
268-
for bf_param, bf_value in dummy_classifier.__dict__.items():
269-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
270-
if bqml_param is not None:
271-
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
258+
kwargs = utils.retrieve_params_from_bq_model(
259+
cls, bq_model, _BQML_PARAMS_MAPPING
260+
)
272261

273-
new_xgb_classifier = cls(**kwargs)
274-
new_xgb_classifier._bqml_model = core.BqmlModel(session, model)
275-
return new_xgb_classifier
262+
model = cls(**kwargs)
263+
model._bqml_model = core.BqmlModel(session, bq_model)
264+
return model
276265

277266
@property
278267
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
@@ -370,16 +359,16 @@ def __init__(
370359
*,
371360
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
372361
min_tree_child_weight: int = 1,
373-
colsample_bytree=1.0,
374-
colsample_bylevel=1.0,
375-
colsample_bynode=0.8,
376-
gamma=0.00,
362+
colsample_bytree: float = 1.0,
363+
colsample_bylevel: float = 1.0,
364+
colsample_bynode: float = 0.8,
365+
gamma: float = 0.0,
377366
max_depth: int = 15,
378-
subsample=0.8,
379-
reg_alpha=0.0,
380-
reg_lambda=1.0,
381-
tol=0.01,
382-
enable_global_explain=False,
367+
subsample: float = 0.8,
368+
reg_alpha: float = 0.0,
369+
reg_lambda: float = 1.0,
370+
tol: float = 0.01,
371+
enable_global_explain: bool = False,
383372
xgboost_version: Literal["0.9", "1.1"] = "0.9",
384373
):
385374
self.n_estimators = n_estimators
@@ -401,24 +390,17 @@ def __init__(
401390

402391
@classmethod
403392
def _from_bq(
404-
cls, session: bigframes.Session, model: bigquery.Model
393+
cls, session: bigframes.Session, bq_model: bigquery.Model
405394
) -> RandomForestRegressor:
406-
assert model.model_type == "RANDOM_FOREST_REGRESSOR"
407-
408-
kwargs = {}
409-
410-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
411-
last_fitting = model.training_runs[-1]["trainingOptions"]
395+
assert bq_model.model_type == "RANDOM_FOREST_REGRESSOR"
412396

413-
dummy_model = cls()
414-
for bf_param, bf_value in dummy_model.__dict__.items():
415-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
416-
if bqml_param in last_fitting:
417-
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
397+
kwargs = utils.retrieve_params_from_bq_model(
398+
cls, bq_model, _BQML_PARAMS_MAPPING
399+
)
418400

419-
new_random_forest_regressor = cls(**kwargs)
420-
new_random_forest_regressor._bqml_model = core.BqmlModel(session, model)
421-
return new_random_forest_regressor
401+
model = cls(**kwargs)
402+
model._bqml_model = core.BqmlModel(session, bq_model)
403+
return model
422404

423405
@property
424406
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
@@ -542,7 +524,7 @@ def __init__(
542524
reg_alpha: float = 0.0,
543525
reg_lambda: float = 1.0,
544526
tol: float = 0.01,
545-
enable_global_explain=False,
527+
enable_global_explain: bool = False,
546528
xgboost_version: Literal["0.9", "1.1"] = "0.9",
547529
):
548530
self.n_estimators = n_estimators
@@ -564,24 +546,17 @@ def __init__(
564546

565547
@classmethod
566548
def _from_bq(
567-
cls, session: bigframes.Session, model: bigquery.Model
549+
cls, session: bigframes.Session, bq_model: bigquery.Model
568550
) -> RandomForestClassifier:
569-
assert model.model_type == "RANDOM_FOREST_CLASSIFIER"
570-
571-
kwargs = {}
551+
assert bq_model.model_type == "RANDOM_FOREST_CLASSIFIER"
572552

573-
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
574-
last_fitting = model.training_runs[-1]["trainingOptions"]
575-
576-
dummy_model = RandomForestClassifier()
577-
for bf_param, bf_value in dummy_model.__dict__.items():
578-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
579-
if bqml_param is not None:
580-
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
553+
kwargs = utils.retrieve_params_from_bq_model(
554+
cls, bq_model, _BQML_PARAMS_MAPPING
555+
)
581556

582-
new_random_forest_classifier = cls(**kwargs)
583-
new_random_forest_classifier._bqml_model = core.BqmlModel(session, model)
584-
return new_random_forest_classifier
557+
model = cls(**kwargs)
558+
model._bqml_model = core.BqmlModel(session, bq_model)
559+
return model
585560

586561
@property
587562
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:

bigframes/ml/forecasting.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"auto_arima_min_order": "autoArimaMinOrder",
3333
"order": "nonSeasonalOrder",
3434
"data_frequency": "dataFrequency",
35+
"include_drift": "includeDrift",
3536
"holiday_region": "holidayRegion",
3637
"clean_spikes_and_dips": "cleanSpikesAndDips",
3738
"adjust_step_changes": "adjustStepChanges",
@@ -131,35 +132,18 @@ def __init__(
131132
self._bqml_model_factory = globals.bqml_model_factory()
132133

133134
@classmethod
134-
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ARIMAPlus:
135-
assert model.model_type == "ARIMA_PLUS"
136-
137-
kwargs: dict = {}
138-
last_fitting = model.training_runs[-1]["trainingOptions"]
139-
140-
dummy_arima = cls()
141-
for bf_param, bf_value in dummy_arima.__dict__.items():
142-
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
143-
if bqml_param in last_fitting:
144-
# Convert types
145-
if bf_param in ["time_series_length_fraction"]:
146-
kwargs[bf_param] = float(last_fitting[bqml_param])
147-
elif bf_param in [
148-
"auto_arima_max_order",
149-
"auto_arima_min_order",
150-
"min_time_series_length",
151-
"max_time_series_length",
152-
"trend_smoothing_window_size",
153-
]:
154-
kwargs[bf_param] = int(last_fitting[bqml_param])
155-
elif bf_param in ["holiday_region"]:
156-
kwargs[bf_param] = str(last_fitting[bqml_param])
157-
else:
158-
kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
159-
160-
new_arima_plus = cls(**kwargs)
161-
new_arima_plus._bqml_model = core.BqmlModel(session, model)
162-
return new_arima_plus
135+
def _from_bq(
136+
cls, session: bigframes.Session, bq_model: bigquery.Model
137+
) -> ARIMAPlus:
138+
assert bq_model.model_type == "ARIMA_PLUS"
139+
140+
kwargs = utils.retrieve_params_from_bq_model(
141+
cls, bq_model, _BQML_PARAMS_MAPPING
142+
)
143+
144+
model = cls(**kwargs)
145+
model._bqml_model = core.BqmlModel(session, bq_model)
146+
return model
163147

164148
@property
165149
def _bqml_options(self) -> dict:

bigframes/ml/imported.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ def _create_bqml_model(self):
5656

5757
@classmethod
5858
def _from_bq(
59-
cls, session: bigframes.Session, model: bigquery.Model
59+
cls, session: bigframes.Session, bq_model: bigquery.Model
6060
) -> TensorFlowModel:
61-
assert model.model_type == "TENSORFLOW"
61+
assert bq_model.model_type == "TENSORFLOW"
6262

63-
tf_model = cls(session=session, model_path="")
64-
tf_model._bqml_model = core.BqmlModel(session, model)
65-
return tf_model
63+
model = cls(session=session, model_path="")
64+
model._bqml_model = core.BqmlModel(session, bq_model)
65+
return model
6666

6767
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
6868
"""Predict the result from input DataFrame.
@@ -134,12 +134,14 @@ def _create_bqml_model(self):
134134
)
135135

136136
@classmethod
137-
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ONNXModel:
138-
assert model.model_type == "ONNX"
137+
def _from_bq(
138+
cls, session: bigframes.Session, bq_model: bigquery.Model
139+
) -> ONNXModel:
140+
assert bq_model.model_type == "ONNX"
139141

140-
onnx_model = cls(session=session, model_path="")
141-
onnx_model._bqml_model = core.BqmlModel(session, model)
142-
return onnx_model
142+
model = cls(session=session, model_path="")
143+
model._bqml_model = core.BqmlModel(session, bq_model)
144+
return model
143145

144146
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
145147
"""Predict the result from input DataFrame.
@@ -249,13 +251,13 @@ def _create_bqml_model(self):
249251

250252
@classmethod
251253
def _from_bq(
252-
cls, session: bigframes.Session, model: bigquery.Model
254+
cls, session: bigframes.Session, bq_model: bigquery.Model
253255
) -> XGBoostModel:
254-
assert model.model_type == "XGBOOST"
256+
assert bq_model.model_type == "XGBOOST"
255257

256-
xgboost_model = cls(session=session, model_path="")
257-
xgboost_model._bqml_model = core.BqmlModel(session, model)
258-
return xgboost_model
258+
model = cls(session=session, model_path="")
259+
model._bqml_model = core.BqmlModel(session, bq_model)
260+
return model
259261

260262
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
261263
"""Predict the result from input DataFrame.

0 commit comments

Comments
 (0)