30
30
31
31
_BQML_PARAMS_MAPPING = {
32
32
"booster" : "boosterType" ,
33
+ "dart_normalized_type" : "dartNormalizeType" ,
33
34
"tree_method" : "treeMethod" ,
34
- "colsample_bytree" : "colsampleBylevel " ,
35
- "colsample_bylevel" : "colsampleBytree " ,
35
+ "colsample_bytree" : "colsampleBytree " ,
36
+ "colsample_bylevel" : "colsampleBylevel " ,
36
37
"colsample_bynode" : "colsampleBynode" ,
37
38
"gamma" : "minSplitLoss" ,
38
39
"subsample" : "subsample" ,
44
45
"min_tree_child_weight" : "minTreeChildWeight" ,
45
46
"max_depth" : "maxTreeDepth" ,
46
47
"max_iterations" : "maxIterations" ,
48
+ "enable_global_explain" : "enableGlobalExplain" ,
49
+ "xgboost_version" : "xgboostVersion" ,
47
50
}
48
51
49
52
@@ -99,24 +102,17 @@ def __init__(
99
102
100
103
@classmethod
101
104
def _from_bq (
102
- cls , session : bigframes .Session , model : bigquery .Model
105
+ cls , session : bigframes .Session , bq_model : bigquery .Model
103
106
) -> XGBRegressor :
104
- assert model .model_type == "BOOSTED_TREE_REGRESSOR"
107
+ assert bq_model .model_type == "BOOSTED_TREE_REGRESSOR"
105
108
106
- kwargs = {}
107
-
108
- # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
109
- last_fitting = model .training_runs [- 1 ]["trainingOptions" ]
110
-
111
- dummy_regressor = cls ()
112
- for bf_param , bf_value in dummy_regressor .__dict__ .items ():
113
- bqml_param = _BQML_PARAMS_MAPPING .get (bf_param )
114
- if bqml_param in last_fitting :
115
- kwargs [bf_param ] = type (bf_value )(last_fitting [bqml_param ])
109
+ kwargs = utils .retrieve_params_from_bq_model (
110
+ cls , bq_model , _BQML_PARAMS_MAPPING
111
+ )
116
112
117
- new_xgb_regressor = cls (** kwargs )
118
- new_xgb_regressor ._bqml_model = core .BqmlModel (session , model )
119
- return new_xgb_regressor
113
+ model = cls (** kwargs )
114
+ model ._bqml_model = core .BqmlModel (session , bq_model )
115
+ return model
120
116
121
117
@property
122
118
def _bqml_options (self ) -> Dict [str , str | int | bool | float | List [str ]]:
@@ -255,24 +251,17 @@ def __init__(
255
251
256
252
@classmethod
257
253
def _from_bq (
258
- cls , session : bigframes .Session , model : bigquery .Model
254
+ cls , session : bigframes .Session , bq_model : bigquery .Model
259
255
) -> XGBClassifier :
260
- assert model .model_type == "BOOSTED_TREE_CLASSIFIER"
256
+ assert bq_model .model_type == "BOOSTED_TREE_CLASSIFIER"
261
257
262
- kwargs = {}
263
-
264
- # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
265
- last_fitting = model .training_runs [- 1 ]["trainingOptions" ]
266
-
267
- dummy_classifier = XGBClassifier ()
268
- for bf_param , bf_value in dummy_classifier .__dict__ .items ():
269
- bqml_param = _BQML_PARAMS_MAPPING .get (bf_param )
270
- if bqml_param is not None :
271
- kwargs [bf_param ] = type (bf_value )(last_fitting [bqml_param ])
258
+ kwargs = utils .retrieve_params_from_bq_model (
259
+ cls , bq_model , _BQML_PARAMS_MAPPING
260
+ )
272
261
273
- new_xgb_classifier = cls (** kwargs )
274
- new_xgb_classifier ._bqml_model = core .BqmlModel (session , model )
275
- return new_xgb_classifier
262
+ model = cls (** kwargs )
263
+ model ._bqml_model = core .BqmlModel (session , bq_model )
264
+ return model
276
265
277
266
@property
278
267
def _bqml_options (self ) -> Dict [str , str | int | bool | float | List [str ]]:
@@ -370,16 +359,16 @@ def __init__(
370
359
* ,
371
360
tree_method : Literal ["auto" , "exact" , "approx" , "hist" ] = "auto" ,
372
361
min_tree_child_weight : int = 1 ,
373
- colsample_bytree = 1.0 ,
374
- colsample_bylevel = 1.0 ,
375
- colsample_bynode = 0.8 ,
376
- gamma = 0.00 ,
362
+ colsample_bytree : float = 1.0 ,
363
+ colsample_bylevel : float = 1.0 ,
364
+ colsample_bynode : float = 0.8 ,
365
+ gamma : float = 0.0 ,
377
366
max_depth : int = 15 ,
378
- subsample = 0.8 ,
379
- reg_alpha = 0.0 ,
380
- reg_lambda = 1.0 ,
381
- tol = 0.01 ,
382
- enable_global_explain = False ,
367
+ subsample : float = 0.8 ,
368
+ reg_alpha : float = 0.0 ,
369
+ reg_lambda : float = 1.0 ,
370
+ tol : float = 0.01 ,
371
+ enable_global_explain : bool = False ,
383
372
xgboost_version : Literal ["0.9" , "1.1" ] = "0.9" ,
384
373
):
385
374
self .n_estimators = n_estimators
@@ -401,24 +390,17 @@ def __init__(
401
390
402
391
@classmethod
403
392
def _from_bq (
404
- cls , session : bigframes .Session , model : bigquery .Model
393
+ cls , session : bigframes .Session , bq_model : bigquery .Model
405
394
) -> RandomForestRegressor :
406
- assert model .model_type == "RANDOM_FOREST_REGRESSOR"
407
-
408
- kwargs = {}
409
-
410
- # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
411
- last_fitting = model .training_runs [- 1 ]["trainingOptions" ]
395
+ assert bq_model .model_type == "RANDOM_FOREST_REGRESSOR"
412
396
413
- dummy_model = cls ()
414
- for bf_param , bf_value in dummy_model .__dict__ .items ():
415
- bqml_param = _BQML_PARAMS_MAPPING .get (bf_param )
416
- if bqml_param in last_fitting :
417
- kwargs [bf_param ] = type (bf_value )(last_fitting [bqml_param ])
397
+ kwargs = utils .retrieve_params_from_bq_model (
398
+ cls , bq_model , _BQML_PARAMS_MAPPING
399
+ )
418
400
419
- new_random_forest_regressor = cls (** kwargs )
420
- new_random_forest_regressor ._bqml_model = core .BqmlModel (session , model )
421
- return new_random_forest_regressor
401
+ model = cls (** kwargs )
402
+ model ._bqml_model = core .BqmlModel (session , bq_model )
403
+ return model
422
404
423
405
@property
424
406
def _bqml_options (self ) -> Dict [str , str | int | bool | float | List [str ]]:
@@ -542,7 +524,7 @@ def __init__(
542
524
reg_alpha : float = 0.0 ,
543
525
reg_lambda : float = 1.0 ,
544
526
tol : float = 0.01 ,
545
- enable_global_explain = False ,
527
+ enable_global_explain : bool = False ,
546
528
xgboost_version : Literal ["0.9" , "1.1" ] = "0.9" ,
547
529
):
548
530
self .n_estimators = n_estimators
@@ -564,24 +546,17 @@ def __init__(
564
546
565
547
@classmethod
566
548
def _from_bq (
567
- cls , session : bigframes .Session , model : bigquery .Model
549
+ cls , session : bigframes .Session , bq_model : bigquery .Model
568
550
) -> RandomForestClassifier :
569
- assert model .model_type == "RANDOM_FOREST_CLASSIFIER"
570
-
571
- kwargs = {}
551
+ assert bq_model .model_type == "RANDOM_FOREST_CLASSIFIER"
572
552
573
- # See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
574
- last_fitting = model .training_runs [- 1 ]["trainingOptions" ]
575
-
576
- dummy_model = RandomForestClassifier ()
577
- for bf_param , bf_value in dummy_model .__dict__ .items ():
578
- bqml_param = _BQML_PARAMS_MAPPING .get (bf_param )
579
- if bqml_param is not None :
580
- kwargs [bf_param ] = type (bf_value )(last_fitting [bqml_param ])
553
+ kwargs = utils .retrieve_params_from_bq_model (
554
+ cls , bq_model , _BQML_PARAMS_MAPPING
555
+ )
581
556
582
- new_random_forest_classifier = cls (** kwargs )
583
- new_random_forest_classifier ._bqml_model = core .BqmlModel (session , model )
584
- return new_random_forest_classifier
557
+ model = cls (** kwargs )
558
+ model ._bqml_model = core .BqmlModel (session , bq_model )
559
+ return model
585
560
586
561
@property
587
562
def _bqml_options (self ) -> Dict [str , str | int | bool | float | List [str ]]:
0 commit comments