Skip to content

Commit 6baf66a

Browse files
committed
update joblib_temp_folder setting
1 parent 035c53c commit 6baf66a

File tree

3 files changed

+77
-15
lines changed

3 files changed

+77
-15
lines changed

stemflow/model/AdaSTEM.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def __init__(
114114
lazy_loading_dir: Union[str, None] = None,
115115
min_class_sample: int = 1,
116116
ensemble_bootstrap: bool = False,
117-
joblib_backend: str = 'loky'
117+
joblib_backend: str = 'loky',
118+
joblib_temp_folder: Union[None, str] = None
118119
):
119120
"""Make an AdaSTEM object
120121
@@ -192,6 +193,8 @@ def __init__(
192193
Whether to bootstrap the data at each ensemble level to account for uncertainty. Defaults to False.
193194
joblib_backend:
194195
The backend of joblib. Defaults to 'loky'. Other options include 'multiprocessing', 'threading'.
196+
joblib_temp_folder:
197+
The temporary folder for joblib. If None, falls back to joblib's default directory. If 'lazy_loading_dir', the same directory as lazy_loading_dir is used. If any other string, it is treated as a directory path, which is created if it does not already exist, and used as the temporary folder. Defaults to None.
195198
Raises:
196199
AttributeError: Base model does not have method 'fit' or 'predict'
197200
AttributeError: task not in one of ['regression', 'classification', 'hurdle']
@@ -267,6 +270,7 @@ def __init__(
267270
n_jobs = check_transform_n_jobs(self, n_jobs)
268271
self.n_jobs = n_jobs
269272
self.joblib_backend = joblib_backend
273+
self.joblib_temp_folder = joblib_temp_folder
270274

271275
# 7. Plotting params
272276
self.plot_xlims = plot_xlims
@@ -374,7 +378,7 @@ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] =
374378
)
375379

376380
if n_jobs > 1 and isinstance(n_jobs, int):
377-
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.lazy_loading_dir)
381+
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.joblib_temp_folder)
378382
output_generator = parallel(
379383
joblib.delayed(partial_get_one_ensemble_quadtree)(
380384
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
@@ -573,7 +577,7 @@ def mp_train(ensemble, self=self, data=data):
573577
res = self.SAC_ensemble_training(index_df=ensemble[1], data=data)
574578
return res
575579

576-
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.lazy_loading_dir)
580+
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.joblib_temp_folder)
577581
output_generator = parallel(joblib.delayed(mp_train)(i) for i in groups)
578582

579583
# tqdm wrapper
@@ -642,6 +646,15 @@ def fit(
642646
shutil.rmtree(self.lazy_loading_dir)
643647
self.lazy_loading_dir = str(Path(self.lazy_loading_dir.rstrip('/\\')))
644648

649+
# Setup joblib_temp_folder
650+
if self.joblib_temp_folder is None:
651+
pass
652+
elif self.joblib_temp_folder=='lazy_loading_dir':
653+
self.joblib_temp_folder = self.lazy_loading_dir
654+
else:
655+
if not os.path.exists(self.joblib_temp_folder):
656+
os.makedirs(self.joblib_temp_folder)
657+
645658
verbosity = check_verbosity(self, verbosity)
646659
check_X_train(X_train)
647660
check_y_train(y_train)
@@ -804,7 +817,7 @@ def mp_predict(ensemble, self=self, data=data):
804817
res = self.SAC_ensemble_predict(index_df=ensemble[1], data=data)
805818
return res
806819

807-
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.lazy_loading_dir)
820+
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.joblib_temp_folder)
808821
output_generator = parallel(joblib.delayed(mp_predict)(i) for i in groups)
809822

810823
# tqdm wrapper
@@ -1224,7 +1237,7 @@ def assign_feature_importances_by_points(
12241237

12251238
# assign input spatio-temporal points to stixels
12261239
if n_jobs > 1:
1227-
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.lazy_loading_dir)
1240+
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.joblib_temp_folder)
12281241
output_generator = parallel(joblib.delayed(partial_assign_func)(i) for i in list(range(self.ensemble_fold)))
12291242
if verbosity > 0:
12301243
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Querying ensembles: ")
@@ -1381,7 +1394,8 @@ def __init__(
13811394
lazy_loading_dir = None,
13821395
min_class_sample = 1,
13831396
ensemble_bootstrap = False,
1384-
joblib_backend = 'loky'
1397+
joblib_backend = 'loky',
1398+
joblib_temp_folder = None
13851399
):
13861400
super().__init__(
13871401
base_model=base_model,
@@ -1416,7 +1430,8 @@ def __init__(
14161430
lazy_loading_dir=lazy_loading_dir,
14171431
min_class_sample=min_class_sample,
14181432
ensemble_bootstrap=ensemble_bootstrap,
1419-
joblib_backend=joblib_backend
1433+
joblib_backend=joblib_backend,
1434+
joblib_temp_folder = joblib_temp_folder
14201435
)
14211436

14221437
self._estimator_type = 'classifier'
@@ -1569,7 +1584,8 @@ def __init__(
15691584
lazy_loading_dir=None,
15701585
min_class_sample=1,
15711586
ensemble_bootstrap=False,
1572-
joblib_backend='loky'
1587+
joblib_backend='loky',
1588+
joblib_temp_folder=None
15731589
):
15741590
super().__init__(
15751591
base_model=base_model,
@@ -1604,7 +1620,8 @@ def __init__(
16041620
lazy_loading_dir=lazy_loading_dir,
16051621
min_class_sample=min_class_sample,
16061622
ensemble_bootstrap=ensemble_bootstrap,
1607-
joblib_backend=joblib_backend
1623+
joblib_backend=joblib_backend,
1624+
joblib_temp_folder=joblib_temp_folder
16081625
)
16091626

16101627
self._estimator_type = 'regressor'

stemflow/model/STEM.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __init__(
5151
lazy_loading_dir: Union[str, None] = None,
5252
min_class_sample: int = 1,
5353
ensemble_bootstrap: bool = False,
54-
joblib_backend: str = 'loky'
54+
joblib_backend: str = 'loky',
55+
joblib_temp_folder: Union[None, str] = None
5556
):
5657
"""Make a STEM object
5758
@@ -127,6 +128,8 @@ def __init__(
127128
Whether to bootstrap the data at each ensemble level to account for uncertainty. Defaults to False.
128129
joblib_backend:
129130
The backend of joblib. Defaults to 'loky'. Other options include 'multiprocessing', 'threading'.
131+
joblib_temp_folder:
132+
The temporary folder for joblib. If None, falls back to joblib's default directory. If 'lazy_loading_dir', the same directory as lazy_loading_dir is used. If any other string, it is treated as a directory path, which is created if it does not already exist, and used as the temporary folder. Defaults to None.
130133
Raises:
131134
AttributeError: Base model does not have method 'fit' or 'predict'
132135
AttributeError: task not in one of ['regression', 'classification', 'hurdle']
@@ -186,7 +189,8 @@ def __init__(
186189
lazy_loading_dir=lazy_loading_dir,
187190
min_class_sample=min_class_sample,
188191
ensemble_bootstrap=ensemble_bootstrap,
189-
joblib_backend=joblib_backend
192+
joblib_backend=joblib_backend,
193+
joblib_temp_folder=joblib_temp_folder
190194
)
191195

192196
self.grid_len = grid_len
@@ -254,7 +258,8 @@ def __init__(
254258
lazy_loading_dir: Union[str, None] = None,
255259
min_class_sample: int = 1,
256260
ensemble_bootstrap: bool = False,
257-
joblib_backend: str = 'loky'
261+
joblib_backend: str = 'loky',
262+
joblib_temp_folder: Union[None, str] = None
258263
):
259264
super().__init__(
260265
base_model=base_model,
@@ -289,7 +294,8 @@ def __init__(
289294
lazy_loading_dir=lazy_loading_dir,
290295
min_class_sample=min_class_sample,
291296
ensemble_bootstrap=ensemble_bootstrap,
292-
joblib_backend=joblib_backend
297+
joblib_backend=joblib_backend,
298+
joblib_temp_folder=joblib_temp_folder
293299
)
294300

295301
self.grid_len = grid_len
@@ -357,7 +363,8 @@ def __init__(
357363
lazy_loading_dir: Union[str, None] = None,
358364
min_class_sample: int = 1,
359365
ensemble_bootstrap: bool = False,
360-
joblib_backend: str = 'loky'
366+
joblib_backend: str = 'loky',
367+
joblib_temp_folder: Union[None, str]= None
361368
):
362369
super().__init__(
363370
base_model=base_model,
@@ -392,7 +399,8 @@ def __init__(
392399
lazy_loading_dir=lazy_loading_dir,
393400
min_class_sample=min_class_sample,
394401
ensemble_bootstrap=ensemble_bootstrap,
395-
joblib_backend=joblib_backend
402+
joblib_backend=joblib_backend,
403+
joblib_temp_folder=joblib_temp_folder
396404
)
397405

398406
self.grid_len = grid_len

tests/test_joblib_temp_folder.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
import pandas as pd
3+
import os
4+
5+
from stemflow.model.AdaSTEM import AdaSTEM
6+
from stemflow.model_selection import ST_train_test_split
7+
8+
from .make_models import (
9+
make_AdaSTEMClassifier,
10+
make_AdaSTEMRegressor,
11+
make_parallel_SphereAdaClassifier,
12+
make_parallel_STEMClassifier,
13+
make_SphereAdaClassifier,
14+
make_SphereAdaSTEMRegressor,
15+
make_STEMClassifier,
16+
make_STEMRegressor,
17+
)
18+
from .set_up_data import set_up_data
19+
20+
x_names, (X, y) = set_up_data()
21+
X_train, X_test, y_train, y_test = ST_train_test_split(
22+
X, y, Spatio_blocks_count=100, Temporal_blocks_count=100, random_state=42, test_size=0.3
23+
)
24+
25+
26+
def test_AdaSTEMRegressor_custom_temp_folder1():
27+
model = make_AdaSTEMRegressor(lazy_loading=True, joblib_temp_folder='lazy_loading_dir')
28+
model = model.fit(X_train, np.where(y_train > 0, 1, 0))
29+
30+
pred_mean, pred_std = model.predict(X_test.reset_index(drop=True), return_std=True, verbosity=1, n_jobs=1)
31+
32+
def test_AdaSTEMRegressor_custom_temp_folder2():
33+
model = make_AdaSTEMRegressor(lazy_loading=True, joblib_temp_folder='./test_tmp_folder')
34+
model = model.fit(X_train, np.where(y_train > 0, 1, 0))
35+
assert os.path.exists('./test_tmp_folder')
36+
37+
pred_mean, pred_std = model.predict(X_test.reset_index(drop=True), return_std=True, verbosity=1, n_jobs=1)

0 commit comments

Comments
 (0)