Skip to content

Commit 101d501

Browse files
committed
add argument base_model_method for predict function
1 parent d544262 commit 101d501

File tree

4 files changed

+124
-14
lines changed

4 files changed

+124
-14
lines changed

stemflow/model/AdaSTEM.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,11 @@ def stixel_predict(self, stixel: pd.core.frame.DataFrame) -> Union[None, pd.core
672672
if model_x_names_tuple[0] is None:
673673
return None
674674

675-
pred = predict_one_stixel(stixel, self.task, model_x_names_tuple, **self.base_model_prediction_param)
675+
pred = predict_one_stixel(X_test_stixel=stixel,
676+
task=self.task,
677+
model_x_names_tuple=model_x_names_tuple,
678+
base_model_method=self.base_model_method,
679+
**self.base_model_prediction_param)
676680

677681
if pred is None:
678682
return None
@@ -814,6 +818,7 @@ def predict_proba(
814818
aggregation: str = "mean",
815819
return_by_separate_ensembles: bool = False,
816820
logit_agg: bool = False,
821+
base_model_method: Union[None, str] = None,
817822
**base_model_prediction_param
818823
) -> Union[np.ndarray, Tuple[np.ndarray]]:
819824
"""Predict probability
@@ -836,7 +841,11 @@ def predict_proba(
836841
return_by_separate_ensembles (bool, optional):
837842
Experimental function. return not by aggregation, but by separate ensembles.
838843
logit_agg:
839-
Whether to use logit aggregation for the classification task. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-tranforms it to probability scale. It's recommended to be jointly used with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability. If False, the output is essentially the proportion of "1s" across the related ensembles; e.g., if 100 stixels covers this spatiotemporal points, and 90% of them predict that it is a "1", then the output probability is 0.9; Therefore it would be a probability estimated by the spatiotemporal neighborhood. Default is False, but can be set to truth for "real" probability averaging.
844+
Whether to use logit aggregation for the classification task. Most likely only used when you are predicting "real" calibrated probability. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-transforms it to probability scale. It's recommended to be jointly used with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability. Default is False, but can be set to True for "real" probability averaging.
845+
base_model_method:
846+
The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Notice that the dummy model will still predict 0, so the ensemble-aggregated result is still an average of zeros and your special prediction function output. Therefore, it may only make sense if your special prediction function predicts 0 as the absence/control value. Defaults to None.
847+
base_model_prediction_param:
848+
Any other parameters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1}.
840849
Raises:
841850
TypeError:
842851
X_test is not of type pd.core.frame.DataFrame.
@@ -855,6 +864,7 @@ def predict_proba(
855864
return_by_separate_ensembles, return_std = check_prediction_return(return_by_separate_ensembles, return_std)
856865
verbosity = check_verbosity(self, verbosity)
857866
n_jobs = check_transform_n_jobs(self, n_jobs)
867+
self.base_model_method = base_model_method
858868
self.base_model_prediction_param = base_model_prediction_param
859869

860870
# predict
@@ -889,7 +899,7 @@ def predict_proba(
889899
res_mean = res.mean(axis=1, skipna=True) # mean of all grid model that predicts this stixel
890900
elif aggregation == "median":
891901
res_mean = res.median(axis=1, skipna=True)
892-
902+
893903
res_std = res.std(axis=1, skipna=True)
894904

895905
# Nan count
@@ -935,6 +945,7 @@ def predict(
935945
aggregation: str = "mean",
936946
return_by_separate_ensembles: bool = False,
937947
logit_agg: bool = False,
948+
base_model_method: Union[None, str] = None,
938949
**base_model_prediction_param
939950
) -> Union[np.ndarray, Tuple[np.ndarray]]:
940951
pass
@@ -1406,6 +1417,7 @@ def predict(
14061417
aggregation: str = "mean",
14071418
return_by_separate_ensembles: bool = False,
14081419
logit_agg: bool = False,
1420+
base_model_method: Union[None, str] = None,
14091421
**base_model_prediction_param
14101422
) -> Union[np.ndarray, Tuple[np.ndarray]]:
14111423
"""A rewrite of predict_proba adapted for Classifier
@@ -1431,10 +1443,12 @@ def predict(
14311443
'mean' or 'median' for aggregation method across ensembles.
14321444
return_by_separate_ensembles (bool, optional):
14331445
Experimental function. return not by aggregation, but by separate ensembles.
1434-
base_model_prediction_param:
1435-
Additional parameter passed to base_model.predict_proba or base_model.predict
14361446
logit_agg:
1437-
Whether to use logit aggregation for the classification task. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-tranform it to probability scale. It's recommened to be combinedly used with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability. If False, the output is the essentially the proportion of "1s" acorss the related ensembles; e.g., if 100 stixels covers this spatiotemporal points, and 90% of them predict that it is a "1", then the ouput probability is 0.9; Therefore it would be a probability estimated by the spatiotemporal neiborhood.
1447+
Whether to use logit aggregation for the classification task. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-transforms it to probability scale. It's recommended to be used in combination with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability.
1448+
base_model_method:
1449+
The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Defaults to None.
1450+
base_model_prediction_param:
1451+
Any other parameters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1}.
14381452
Raises:
14391453
TypeError:
14401454
X_test is not of type pd.core.frame.DataFrame.
@@ -1457,6 +1471,7 @@ def predict(
14571471
aggregation=aggregation,
14581472
return_by_separate_ensembles=return_by_separate_ensembles,
14591473
logit_agg=logit_agg,
1474+
base_model_method=base_model_method,
14601475
**base_model_prediction_param
14611476
)
14621477
mean = mean[:,1].flatten()
@@ -1473,6 +1488,7 @@ def predict(
14731488
aggregation=aggregation,
14741489
return_by_separate_ensembles=return_by_separate_ensembles,
14751490
logit_agg=logit_agg,
1491+
base_model_method=base_model_method,
14761492
**base_model_prediction_param
14771493
)
14781494
mean = mean[:,1].flatten()
@@ -1588,6 +1604,7 @@ def predict(
15881604
n_jobs: Union[None, int] = 1,
15891605
aggregation: str = "mean",
15901606
return_by_separate_ensembles: bool = False,
1607+
base_model_method: Union[None, str] = None,
15911608
**base_model_prediction_param
15921609
) -> Union[np.ndarray, Tuple[np.ndarray]]:
15931610
"""A rewrite of predict_proba
@@ -1609,8 +1626,10 @@ def predict(
16091626
'mean' or 'median' for aggregation method across ensembles.
16101627
return_by_separate_ensembles (bool, optional):
16111628
Experimental function. return not by aggregation, but by separate ensembles.
1629+
base_model_method:
1630+
The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Defaults to None.
16121631
base_model_prediction_param:
1613-
Additional parameter passed to base_model.predict_proba or base_model.predict
1632+
Any other parameters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1}.
16141633
16151634
Raises:
16161635
TypeError:
@@ -1633,6 +1652,7 @@ def predict(
16331652
n_jobs=n_jobs,
16341653
aggregation=aggregation,
16351654
return_by_separate_ensembles=return_by_separate_ensembles,
1655+
base_model_method = base_model_method,
16361656
**base_model_prediction_param
16371657
)
16381658

stemflow/model/static_func_AdaSTEM.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ def predict_one_stixel(
431431
X_test_stixel: pd.core.frame.DataFrame,
432432
task: str,
433433
model_x_names_tuple: Tuple[Union[None, BaseEstimator], list],
434+
base_model_method: Union[None, str],
434435
**base_model_prediction_param
435436
) -> pd.core.frame.DataFrame:
436437
"""predict_one_stixel
@@ -439,6 +440,7 @@ def predict_one_stixel(
439440
X_test_stixel (pd.core.frame.DataFrame): Input testing variables
440441
task (str): One of 'regression', 'classification' and 'hurdle'
441442
model_x_names_tuple (tuple[Union[None, BaseEstimator], list]): A tuple of (model, stixel_specific_x_names)
443+
base_model_method (Union[None, str]): The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function.
442444
base_model_prediction_param: Additional parameter passed to base_model.predict_proba or base_model.predict
443445
444446
Returns:
@@ -452,13 +454,25 @@ def predict_one_stixel(
452454
return None
453455

454456
# get test data
455-
if task == "regression":
456-
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]])
457-
else:
458-
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
459-
pred = pred[:,1]
460-
461-
457+
pred = None
458+
if base_model_method is not None:
459+
if hasattr(model_x_names_tuple[0], base_model_method):
460+
pred_func = getattr(model_x_names_tuple[0], base_model_method)
461+
pred = pred_func(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
462+
else:
463+
if isinstance(model_x_names_tuple[0], dummy_model1):
464+
pass
465+
else:
466+
raise TypeError(f"{base_model_method} does not exists for base model {type(model_x_names_tuple[0])}")
467+
468+
if pred is None:
469+
# Still haven't found the pred function
470+
if task == "regression":
471+
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]])
472+
else:
473+
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
474+
pred = pred[:,1]
475+
462476
res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index")
463477

464478
return res

tests/make_models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,4 +365,31 @@ def make_AdaSTEMClassifier_caliP(fold_=2, min_req=1, **kwargs):
365365
min_class_sample=3,
366366
**kwargs
367367
)
368+
return model
369+
370+
371+
def make_AdaSTEMClassifier_custom_pred_method(base_model_class, fold_=2, min_req=1, **kwargs):
372+
373+
model = AdaSTEMClassifier(
374+
base_model=base_model_class(),
375+
save_gridding_plot=True,
376+
ensemble_fold=fold_,
377+
min_ensemble_required=min_req,
378+
grid_len_upper_threshold=50,
379+
grid_len_lower_threshold=20,
380+
temporal_start=1,
381+
temporal_end=366,
382+
temporal_step=40,
383+
temporal_bin_interval=80,
384+
points_lower_threshold=30,
385+
Spatio1="longitude",
386+
Spatio2="latitude",
387+
Temporal1="DOY",
388+
temporal_bin_start_jitter="adaptive",
389+
spatio_bin_jitter_magnitude="adaptive",
390+
use_temporal_to_train=True,
391+
n_jobs=1,
392+
sample_weights_for_classifier=False,
393+
**kwargs
394+
)
368395
return model
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from stemflow.model.AdaSTEM import AdaSTEM
5+
from stemflow.model_selection import ST_train_test_split
6+
from xgboost import XGBClassifier, XGBRegressor
7+
8+
from .make_models import (
9+
make_AdaSTEMClassifier,
10+
make_AdaSTEMClassifier_custom_pred_method
11+
)
12+
from .set_up_data import set_up_data
13+
14+
x_names, (X, y) = set_up_data()
15+
X_train, X_test, y_train, y_test = ST_train_test_split(
16+
X, y, Spatio_blocks_count=100, Temporal_blocks_count=100, random_state=42, test_size=0.3
17+
)
18+
def test_AdaSTEMClassifier():
19+
20+
class my_base_model:
21+
def __init__(self):
22+
self.model = XGBClassifier(tree_method="hist", random_state=42, verbosity=0, n_jobs=1)
23+
pass
24+
def fit(self, X_train, y_train):
25+
self.model.fit(X_train, y_train)
26+
return self
27+
def predict(self, X_test):
28+
return self.model.predict(X_test)
29+
def predict_proba(self, X_test):
30+
return self.model.predict_proba(X_test)
31+
def special_predict(self, X_test):
32+
# Fold change
33+
pred1 = self.model.predict_proba(X_test)[:,1]
34+
pred2 = self.model.predict_proba(X_test + np.random.normal(loc=0, scale=1, size=X_test.shape))[:,1]
35+
# Interaction
36+
i_ = np.log(np.clip(1e-6, 1-1e-6, pred1) / np.clip(1e-6, 1-1e-6, pred2))
37+
pred = i_ # Should be -inf to inf, 0 as no interaction
38+
return pred
39+
40+
model = make_AdaSTEMClassifier_custom_pred_method(base_model_class=my_base_model)
41+
# model = make_AdaSTEMClassifier()
42+
model = model.fit(X_train, np.where(y_train > 0, 1, 0))
43+
44+
pred_mean = model.predict_proba(X_test.reset_index(drop=True), return_std=False, verbosity=1, n_jobs=1, base_model_method='special_predict')[:,1]
45+
pred_mean = pred_mean[~np.isnan(pred_mean)]
46+
47+
# print(pred_mean)
48+
# assert(np.sum((pred_mean < 1000) & (pred_mean > 1001)))
49+

0 commit comments

Comments
 (0)