Skip to content

Commit d060e80

Browse files
committed
add model saving verbosity option
1 parent b5a807a commit d060e80

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

stemflow/model/AdaSTEM.py

Lines changed: 31 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -954,13 +954,13 @@ def find_belonged_points_and_predict(df, st_indexes_df, X_df):
954954

955955
window_prediction_list.append(res)
956956

957-
if any([i is not None for i in window_prediction_list]):
958-
ensemble_prediction = pd.concat(window_prediction_list, axis=0)
959-
ensemble_prediction = ensemble_prediction.groupby("index").mean().reset_index(drop=False)
960-
else:
961-
ensmeble_index = list(window_single_ensemble_df["ensemble_index"])[0]
962-
warnings.warn(f"No prediction for this ensemble: {ensmeble_index}")
963-
ensemble_prediction = None
957+
if any([i is not None for i in window_prediction_list]):
958+
ensemble_prediction = pd.concat(window_prediction_list, axis=0)
959+
ensemble_prediction = ensemble_prediction.groupby("index").mean().reset_index(drop=False)
960+
else:
961+
ensmeble_index = list(window_single_ensemble_df["ensemble_index"])[0]
962+
warnings.warn(f"No prediction for this ensemble: {ensmeble_index}")
963+
ensemble_prediction = None
964964

965965
return ensemble_prediction
966966

@@ -1501,25 +1501,35 @@ def load(tar_gz_file, new_lazy_loading_path=None, remove_original_file=False):
15011501

15021502
return model
15031503

1504-
def save(self, tar_gz_file, remove_temporary_file = True):
1505-
if not os.path.exists(self.lazy_loading_dir):
1506-
os.makedirs(self.lazy_loading_dir, exist_ok=False)
1504+
def save(self, tar_gz_file, remove_temporary_file = True, verbosity=1, compresslevel=2):
1505+
os.makedirs(self.lazy_loading_dir, exist_ok=True)
15071506

1508-
# temporary save the model using pickle
1509-
model_path = os.path.join(self.lazy_loading_dir, f'model.pkl')
1507+
# dump the main object
1508+
model_path = os.path.join(self.lazy_loading_dir, 'model.pkl')
15101509
with open(model_path, 'wb') as f:
15111510
pickle.dump(self, f)
1512-
1513-
# save the main model class and potentially lazyloading pieces to the tar.gz file
1514-
with tarfile.open(tar_gz_file, "w:gz") as tar:
1515-
for pieces in os.listdir(self.lazy_loading_dir):
1516-
tar.add(os.path.join(self.lazy_loading_dir, pieces), arcname=pieces)
15171511

1518-
if remove_temporary_file:
1519-
if self.lazy_loading_dir is not None:
1520-
if os.path.exists(self.lazy_loading_dir):
1521-
shutil.rmtree(self.lazy_loading_dir)
1512+
# collect files recursively (deterministic order)
1513+
root = os.path.abspath(self.lazy_loading_dir)
1514+
files = []
1515+
for dp, dns, fns in os.walk(root):
1516+
dns.sort(); fns.sort()
1517+
for fn in fns:
1518+
files.append(os.path.join(dp, fn))
1519+
1520+
it = files
1521+
if verbosity > 0 and tqdm is not None:
1522+
it = tqdm(files, desc="Archiving", unit="file")
15221523

1524+
with tarfile.open(tar_gz_file, "w:gz", compresslevel=compresslevel) as tar:
1525+
for fp in it:
1526+
arcname = os.path.relpath(fp, root)
1527+
tar.add(fp, arcname=arcname, recursive=False)
1528+
1529+
if remove_temporary_file and os.path.exists(self.lazy_loading_dir):
1530+
shutil.rmtree(self.lazy_loading_dir)
1531+
1532+
15231533
@staticmethod
15241534
def _cleanup(lazy_loading_dir):
15251535
if lazy_loading_dir is not None:

0 commit comments

Comments
 (0)