Skip to content

Commit ba54c1b

Browse files
committed
fix bugs
1 parent 167606b commit ba54c1b

9 files changed

+21
-13
lines changed

docs/Examples/01.AdaSTEM_demo.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@
730730
},
731731
{
732732
"cell_type": "code",
733-
"execution_count": 9,
733+
"execution_count": null,
734734
"metadata": {},
735735
"outputs": [],
736736
"source": [
@@ -740,6 +740,7 @@
740740
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
741741
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
742742
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
743+
" task='hurdle',\n",
743744
" save_gridding_plot = True,\n",
744745
" ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
745746
" min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted\n",

docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
},
137137
{
138138
"cell_type": "code",
139-
"execution_count": 8,
139+
"execution_count": null,
140140
"metadata": {},
141141
"outputs": [
142142
{
@@ -195,6 +195,7 @@
195195
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
196196
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
197197
" ),\n",
198+
" task='hurdle',\n",
198199
" save_gridding_plot = True,\n",
199200
" ensemble_fold=10, \n",
200201
" min_ensemble_required=7,\n",

docs/Examples/04.SphereAdaSTEM_demo.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@
721721
},
722722
{
723723
"cell_type": "code",
724-
"execution_count": 9,
724+
"execution_count": null,
725725
"metadata": {},
726726
"outputs": [],
727727
"source": [
@@ -731,6 +731,7 @@
731731
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
732732
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
733733
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
734+
" task='hurdle',\n",
734735
" save_gridding_plot = True,\n",
735736
" ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
736737
" min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted\n",

docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@
651651
},
652652
{
653653
"cell_type": "code",
654-
"execution_count": 8,
654+
"execution_count": null,
655655
"metadata": {},
656656
"outputs": [],
657657
"source": [
@@ -660,6 +660,7 @@
660660
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
661661
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
662662
" ),\n",
663+
" task='hurdle',\n",
663664
" save_gridding_plot = True,\n",
664665
" ensemble_fold=10, \n",
665666
" min_ensemble_required=7,\n",

docs/Examples/06.Base_model_choices.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@
710710
" classifier=base_model_dict[base_model_name]['classifier'],\n",
711711
" regressor=base_model_dict[base_model_name]['regressor']\n",
712712
" ),\n",
713+
" task='hurdle',\n",
713714
" save_gridding_plot = True,\n",
714715
" ensemble_fold=10, \n",
715716
" min_ensemble_required=7,\n",

docs/Examples/07.Optimizing_stixel_size.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@
581581
},
582582
{
583583
"cell_type": "code",
584-
"execution_count": 10,
584+
"execution_count": null,
585585
"metadata": {},
586586
"outputs": [
587587
{
@@ -1154,6 +1154,7 @@
11541154
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
11551155
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
11561156
" ),\n",
1157+
" task='hurdle',\n",
11571158
" save_gridding_plot = True,\n",
11581159
" ensemble_fold=10, \n",
11591160
" min_ensemble_required=7,\n",

docs/Examples/08.Lazy_loading.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@
678678
},
679679
{
680680
"cell_type": "code",
681-
"execution_count": 9,
681+
"execution_count": null,
682682
"metadata": {},
683683
"outputs": [],
684684
"source": [
@@ -691,6 +691,7 @@
691691
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
692692
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
693693
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
694+
" task='hurdle',\n",
694695
" save_gridding_plot = True,\n",
695696
" ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
696697
" min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n",
@@ -720,6 +721,7 @@
720721
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
721722
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
722723
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
724+
" task='hurdle',\n",
723725
" save_gridding_plot = True,\n",
724726
" ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
725727
" min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n",

stemflow/model/AdaSTEM.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,9 @@ def eval_STEM_res(
10221022
elif task == "hurdle":
10231023
cls_threshold = 0
10241024

1025-
if not task == "regression":
1025+
if task == "regression":
1026+
auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6
1027+
else:
10261028
a = pd.DataFrame({"y_true": np.array(y_test).flatten(), "pred": np.array(y_pred).flatten()}).dropna()
10271029

10281030
y_test_b = np.where(a.y_true > cls_threshold, 1, 0)
@@ -1039,9 +1041,6 @@ def eval_STEM_res(
10391041
recall = recall_score(y_test_b, y_pred_b)
10401042
average_precision = average_precision_score(y_test_b, y_pred_b)
10411043

1042-
else:
1043-
auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6
1044-
10451044
if not task == "classification":
10461045
a = pd.DataFrame({"y_true": y_test, "pred": y_pred}).dropna()
10471046
s_r, _ = spearmanr(np.array(a.y_true), np.array(a.pred))

stemflow/model/static_func_AdaSTEM.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,12 @@ def predict_one_stixel(
467467

468468
if pred is None:
469469
# Still haven't found the pred function
470-
if task == "regression":
471-
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]])
472-
else:
470+
if task == "classification":
473471
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
474472
pred = pred[:,1]
473+
else:
474+
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
475+
475476

476477
res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index")
477478

0 commit comments

Comments
 (0)