diff --git a/NEWS.md b/NEWS.md
index d93cb77..5398894 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,30 @@
+-------
+stemflow version 1.1.6
+-------
+**Oct, 2025**
+
+Fixed several issues: a prediction bug and a lazy-loading bug; updated the plotting function; updated the docs. #82. Also fixed a previous bug where, after getting an attribute of a LazyLoadingEstimator object, the model was not auto-dumped.
+
+
+-------
+stemflow version 1.1.5
+-------
+**Oct, 2025**
+
+This is a large update.
+
+Features:
+1. The major change is that the `AdaSTEM` class now supports `duckdb` and `parquet` file paths as input. This allows users to pass in large datasets without duplicating the pandas dataframe across processors when working with n_jobs>1 parallel computing. See the new Jupyter notebooks for details. #76
+2. Lazy loading is no longer realized by the `LazyLoadingEnsemble` class. Instead, it is realized by `LazyLoadingEstimator`. This allows each model to be dumped once its training/prediction is finished, so we no longer need to accumulate the models (hence, memory) until training is finished for the whole ensemble, which largely reduces memory use. See the new Jupyter notebooks for details. #77
+3. n_jobs > ensemble_folds is no longer supported, for user-end clarity. Jobs are parallelized over ensemble folds, so n_jobs > ensemble_folds is meaningless. We do not want to mislead users into thinking that a 10-ensemble model will be trained faster using n_jobs=20 than using n_jobs=10.
+4. These features will not be available in `SphereAdaSTEM` due to the negligible user base and the negligible advantages. #75
+
+Major bugs fixed:
+1. Previously, the models were stored in `self.model_dict` dynamically during the parallel ensemble training process, which means the dictionary was being altered while training was still running. However, `self` is passed as an input argument when the ensemble-level training function is serialized, and the object being serialized should not be changing. This is fixed by assigning the `model_dict` to `self` after all training is finished.
+2. Also fixed #74
+
+
+
 -------
 stemflow version 1.1.3
 -------
diff --git a/docs/Examples/08.Lazy_loading.ipynb b/docs/Examples/08.Lazy_loading.ipynb
index c88360f..5622f57 100644
--- a/docs/Examples/08.Lazy_loading.ipynb
+++ b/docs/Examples/08.Lazy_loading.ipynb
@@ -1560,7 +1560,7 @@
    "metadata": {},
    "source": [
     "From the results we can clearly see the trade-off. Using lazy-loading:\n",
-    "1. It is so interesting that lazy-loading seems to even reduce the prediction time... Maybe because it does not load unnecessary models an only focus on certain stixels that cover the needed points.\n",
+    "1. It is so interesting that lazy-loading seems to even reduce the prediction time. This could be due to two reasons: (1) a lazy-loading model does not load unnecessary models and only focuses on certain stixels that cover the needed points, and (2) joblib does not need to serialize a huge amount of data (models), which saves a lot of time.\n",
    "2. Has large impact on testing (prediction) speed. The time for prediction is more than doubled in our case.\n",
    "3. Lazy-loading will maintain memory-use stable and unchanged as ensemble fold increases (maintaining ~ 3GB in our case), while non-lazy-loading will have linear memory consumption growth."
   ]
  },
  {
@@ -2200,7 +2200,7 @@
    "Still, the memory use will proportionally increase when n_jobs increase. That is because\n",
    "1. Your data is being copied n_jobs times -- once for each processor, because data cannot be shared among processors. This problem cannot be solved by lazy loading, but can be solved by using database query (see the other notebook for how to use duckdb as input).\n",
    "2. The trained models also cost memory. For non-lazy loading, all trained models are saved in memory, so a 10-ensemble model means 10 times more models, therefore memory, than a 1-ensemble model. Despite that, lazy-loading still managed to reduce this memory load by only allowing ~1 models in memory per ensemble (so still proportional to the number of ensembles), and ask that if the model has finished training or predicting, auto-dump itself to disk.\n",
-    "3. It is still surprising that prediction is so much faster when using lazy loading..."
+    "3. Lazy-loading also seems to dramatically reduce the prediction time. This means that avoiding serializing a huge amount of data (models) with joblib matters more than the I/O overhead of reading/dumping individual models."
   ]
  },
  {
diff --git a/requirements.txt b/requirements.txt
index bb4e14a..4a837f1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ scipy>=1.10.1
 setuptools>=68.2.0
 tqdm>=4.65.0
 duckdb>=1.1.3
-pyarrow>=17.0.0
\ No newline at end of file
+pyarrow>=17.0.0
+cartopy>=0.22
diff --git a/stemflow/lazyloading/lazyloading.py b/stemflow/lazyloading/lazyloading.py
index f241a44..4f944cb 100644
--- a/stemflow/lazyloading/lazyloading.py
+++ b/stemflow/lazyloading/lazyloading.py
@@ -41,7 +41,7 @@ def __init__(
         estimator: Optional[BaseEstimator],
         dump_dir: Optional[Path | str] = None,
         filename: Optional[str] = None,
-        compress: Any = 3,
+        compress: Any = 0,
         auto_load: bool = True,
         auto_dump: bool = False,
         keep_loaded: bool = False,
@@ -91,14 +91,16 @@ def __getattr__(self, name):
         # Try autoloading and then delegate
         if name.startswith("__"):  # avoid dunder recursion
             raise AttributeError(name)
-        with self._lock:
-            if self.estimator is None and self.auto_load:
-                self._load_inplace()
-            if self.estimator is not None and hasattr(self.estimator, name):
-                return getattr(self.estimator, name)
-        # Fallback to default behavior
-        raise AttributeError(f"{type(self).__name__} has no attribute '{name}'")
+        if self.estimator is None and not self.auto_load:
+            raise AttributeError(f"Trying to get an attribute of the estimator, but the estimator cannot be auto-loaded from disk because auto_load=False.")
+
+        with self._loaded_estimator() as est:
+            if hasattr(est, name):
+                return getattr(est, name)
+            else:
+                raise AttributeError(f"{type(est).__name__} has no attribute '{name}'")
+
     # ---------- Persistence helpers ----------
     def _resolve_path(self) -> Path:
         if self.dump_dir is None:
@@ -133,6 +135,7 @@ def dump(self) -> Path:
             shutil.move(str(tmp_path), str(path))
             # Free memory
             self.estimator = None
+
         finally:
             # Best-effort cleanup
             try:
@@ -143,8 +146,10 @@ def dump(self) -> Path:
                 tmp_dir.rmdir()
             except Exception:
                 pass
+
         return path
+
     def load(self, path: Optional[Path | str] = None) -> "LazyLoadingEstimator":
         """
         Load the inner estimator from disk into this wrapper (in-place).
@@ -159,7 +164,7 @@ def load_from_dir(
         cls,
         dump_dir: Path | str,
         filename: Optional[str] = None,
-        compress: Any = 3,
+        compress: Any = 0,
         **kwargs,
     ) -> "LazyLoadingEstimator":
         dump_dir = Path(dump_dir)
diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py
index 390ca52..215dc05 100644
--- a/stemflow/model/AdaSTEM.py
+++ b/stemflow/model/AdaSTEM.py
@@ -47,7 +47,7 @@ from ..utils.quadtree import get_one_ensemble_quadtree
 from ..utils.validation import (
     check_base_model,
-    check_prediciton_aggregation,
+    check_prediction_aggregation,
     check_prediction_return,
     check_random_state,
     check_spatial_scale,
@@ -145,7 +145,7 @@ def __init__(
             Overriden by grid_len_*_upper_threshold parameters. Defaults to 50.
         stixel_training_size_threshold:
             Do not train the model if the available data records for this stixel is less than this threshold,
-            and directly set the value to np.nan. Defaults to 50.
+            and directly set the value to np.nan. Defaults to None, in which case it is set to the same value as `points_lower_threshold`.
         temporal_start:
             start of the temporal sequence. Defaults to 1.
         temporal_end:
@@ -200,7 +200,7 @@ def __init__(
         joblib_backend:
             The backend of joblib. Defaults to 'loky'. Other options include 'threading'. ('multiprocessing' not supported because it does not allow generator format).
         max_mem:
-            The maximum memory use during the training or prediction process. Should be format like '60GB', '512MB', '1.5GB'.
+            The maximum memory use during the training or prediction process. Should be in a format like '60GB', '512MB', '1.5GB'. This argument is only valid if your input training/prediction data is a path to duckdb files. Even then, it does not guarantee that the memory use of each joblib worker stays below this value, because of various intermediate objects; it should only be considered a relative constraint.
     Raises:
         AttributeError: Base model do not have method 'fit' or 'predict'
         AttributeError: task not in one of ['regression', 'classification', 'hurdle']
@@ -500,12 +500,12 @@ def stixel_fitting(self, stixel):
             raise AttributeError('__index_level_0__ should not apprear in the final training data!')
         unique_stixel_id = stixel["unique_stixel_id"].iloc[0]
-        name = unique_stixel_id
+        ensemble_index = unique_stixel_id.split('_')[0]
         if self.lazy_loading:
             base_model = LazyLoadingEstimator(estimator=self.base_model,
-                dump_dir=os.path.join(self.lazy_loading_dir, 'models', 'ensemble_' + name.split('_')[1]),
-                filename=f"model_{name}.pkl",
+                dump_dir=os.path.join(self.lazy_loading_dir, 'models', 'ensemble_' + ensemble_index),
+                filename=f"model_{unique_stixel_id}.pkl",
                 auto_dump=True, auto_load=True, keep_loaded=False)
         else:
             base_model = self.base_model
@@ -525,7 +525,8 @@ def stixel_fitting(self, stixel):
             # print(f'Fitting: {ensemble_index}. 
Not pass: {status}') pass else: - return (name, model, stixel_specific_x_names) + return [unique_stixel_id, model, stixel_specific_x_names] + def SAC_ensemble_training(self, single_ensemble_df: pd.DataFrame, X_train: Union[pd.DataFrame, str], y_train: Union[pd.DataFrame, str], temporal_window_prequery: bool = False): @@ -574,9 +575,9 @@ def SAC_ensemble_training(self, single_ensemble_df: pd.DataFrame, X_train: Union ]) # Apply bootstrap temporal_window_indexes = bootstrap_indices[np.isin(bootstrap_indices, temporal_window_indexes)] if self.ensemble_bootstrap else temporal_window_indexes + window_X_df_indexes_only = X_train_df.loc[temporal_window_indexes][[self.Temporal1, self.Spatio1, self.Spatio2]] window_X_df = X_train_df.loc[temporal_window_indexes] if temporal_window_prequery else X_train_df window_y_df = y_train_df.loc[temporal_window_indexes] if temporal_window_prequery else y_train_df - window_X_df_indexes_only = X_train_df.loc[temporal_window_indexes][[self.Temporal1, self.Spatio1, self.Spatio2]] else: temporal_window_indexes = con.sql(f"SELECT __index_level_0__ FROM X_train_df WHERE {self.Temporal1} >= {start} AND {self.Temporal1} < {start + self.temporal_bin_interval};").df().values.flatten() # Apply bootstrap @@ -649,10 +650,13 @@ def find_belonged_points(df, st_indexes_df, X_df, y_df): def find_belonged_points_and_fit(df, st_indexes_df, X_df, y_df): X_y = find_belonged_points(df, st_indexes_df, X_df, y_df) + if len(X_y)==0: + return None X_y['ensemble_index'] = df['ensemble_index'].iloc[0] X_y['unique_stixel_id'] = df['unique_stixel_id'].iloc[0] X_y = X_y.sort_index() # To ensure the input dataframes for the two method (temporal_window_prequery or not) are the same so the tained base models are identical, at least with the same input data - return self.stixel_fitting(X_y) + fitted_res = self.stixel_fitting(X_y) + return fitted_res # train res = ( @@ -812,8 +816,10 @@ def stixel_predict(self, stixel: pd.DataFrame) -> Union[None, pd.DataFrame]: """ if '__index_level_0__' in stixel.columns: raise AttributeError('__index_level_0__ should not apprear in the final training data!') + + if len(stixel) == 0: + return None # No data to predict - stixel['unique_stixel_id'] = stixel.name unique_stixel_id = stixel["unique_stixel_id"].iloc[0] model_x_names_tuple = get_model_and_stixel_specific_x_names( @@ -838,7 +844,7 @@ def stixel_predict(self, stixel: pd.DataFrame) -> Union[None, pd.DataFrame]: return pred def SAC_ensemble_predict( - self, single_ensemble_df: pd.DataFrame, data: Union[pd.DataFrame, str] + self, single_ensemble_df: pd.DataFrame, data: Union[pd.DataFrame, str], temporal_window_prequery: bool = False ) -> pd.DataFrame: """A sub-module of SAC prediction function. Predict only one ensemble. @@ -846,6 +852,7 @@ def SAC_ensemble_predict( Args: single_ensemble_df (pd.DataFrame): ensemble data (model.ensemble_df) data (pd.DataFrame, str): input covariates to predict + temporal_window_prequery (bool): Whether to prequery the temporal windows as pd.DataFrame object to speed-up the stixel query. If set to True, query speed will be faster but with a moderate memory usage increase. Returns: pd.DataFrame: Prediction result of one ensemble. 
""" @@ -868,10 +875,17 @@ def SAC_ensemble_predict( (data_df[self.Temporal1] < start + self.temporal_bin_interval) ]) window_X_df_indexes_only = data_df.loc[temporal_window_indexes][[self.Temporal1, self.Spatio1, self.Spatio2]] + window_X_df = data_df.loc[temporal_window_indexes] if temporal_window_prequery else data_df else: temporal_window_indexes = con.sql(f"SELECT __index_level_0__ FROM data_df WHERE {self.Temporal1} >= {start} AND {self.Temporal1} < {start + self.temporal_bin_interval};").df().values.flatten() temporal_window_indexes_df = pd.DataFrame(temporal_window_indexes, columns=['__index_level_0__']) con.register("temporal_window_indexes_df", temporal_window_indexes_df) + window_X_df = con.sql(f""" + SELECT temporal_window_indexes_df.__index_level_0__, data_df.* EXCLUDE(__index_level_0__) + FROM data_df + JOIN temporal_window_indexes_df + ON data_df.__index_level_0__ = temporal_window_indexes_df.__index_level_0__; + """).df().set_index('__index_level_0__') if temporal_window_prequery else data_df window_X_df_indexes_only = con.sql(f""" SELECT data_df.{self.Temporal1}, data_df.{self.Spatio1}, data_df.{self.Spatio2}, data_df.__index_level_0__ FROM data_df @@ -909,7 +923,17 @@ def find_belonged_points(df, st_indexes_df, X_df): return X - query_results = ( + def find_belonged_points_and_predict(df, st_indexes_df, X_df): + X = find_belonged_points(df, st_indexes_df, X_df) + if len(X)==0: + return None + X['ensemble_index'] = df['ensemble_index'].iloc[0] + X['unique_stixel_id'] = df['unique_stixel_id'].iloc[0] + # X = X.sort_index() # To ensure the input dataframes for the two method (temporal_window_prequery or not) are the same so the tained base models are identical, at least with the same input data + pred = self.stixel_predict(X) + return pred + + res = ( window_single_ensemble_df[ [ "ensemble_index", @@ -920,40 +944,29 @@ def find_belonged_points(df, st_indexes_df, X_df): "stixel_calibration_point_transformed_upper_bound", ] ] - .groupby(["ensemble_index", "unique_stixel_id"], as_index=True) - .pipe(lambda x: x[x.obj.columns]) # Explicitly select all the columns in the original df to include. To overcome the include_groups=True deprecation warning - .apply(find_belonged_points, st_indexes_df=window_X_df_indexes_only, X_df=data_df, include_groups=False) # although ["ensemble_index", "unique_stixel_id"] will be passed into `find_belonged_points` due to `.pipe(lambda x: x[x.obj.columns])`, the output will not have them so we still set `as_index=True` in `groupby` - .reset_index(level=["ensemble_index", "unique_stixel_id"]) # Turn these indexes into columns and keep the original df indexing - ) - - if len(query_results) == 0: - """All points fall out of the grids""" - continue - - # predict - window_prediction = ( - query_results - .dropna(subset="unique_stixel_id") - .groupby("unique_stixel_id", as_index=False) + .groupby(["ensemble_index", "unique_stixel_id"], as_index=False) .pipe(lambda x: x[x.obj.columns]) # Explicitly select all the columns in the original df to include. 
To overcome the include_groups=True deprecation warning - .apply(lambda stixel: self.stixel_predict(stixel), include_groups=False) # - .droplevel(0) # If using as_index=False duing groupby, pandas will automatically generate a group indexing column, so drop the indexing of the new groups + .apply(find_belonged_points_and_predict, st_indexes_df=window_X_df_indexes_only, X_df=window_X_df, include_groups=False) # although ["ensemble_index", "unique_stixel_id"] will be passed into `find_belonged_points` due to `.pipe(lambda x: x[x.obj.columns])`, the output will not have them so we still set `as_index=True` in `groupby` ) - window_prediction_list.append(window_prediction) + if len(res)>0: + res = res.droplevel(0) # If using as_index=False duing groupby, pandas will automatically generate a group indexing column, so drop the indexing of the new groups + + window_prediction_list.append(res) - if any([i is not None for i in window_prediction_list]): - ensemble_prediction = pd.concat(window_prediction_list, axis=0) - ensemble_prediction = ensemble_prediction.groupby("index").mean().reset_index(drop=False) - else: - ensmeble_index = list(window_single_ensemble_df["ensemble_index"])[0] - warnings.warn(f"No prediction for this ensemble: {ensmeble_index}") - ensemble_prediction = None + if any([i is not None for i in window_prediction_list]): + ensemble_prediction = pd.concat(window_prediction_list, axis=0) + ensemble_prediction = ensemble_prediction.groupby("index").mean().reset_index(drop=False) + else: + ensmeble_index = list(window_single_ensemble_df["ensemble_index"])[0] + warnings.warn(f"No prediction for this ensemble: {ensmeble_index}") + ensemble_prediction = None return ensemble_prediction def SAC_predict( - self, ensemble_df: pd.DataFrame, data: Union[pd.DataFrame, str], verbosity: int = 0, n_jobs: int = 1 + self, ensemble_df: pd.DataFrame, data: Union[pd.DataFrame, str], verbosity: int = 0, n_jobs: int = 1, temporal_window_prequery: bool = False + ) -> pd.DataFrame: """This function is a prediction function with SAC strategy: Split (S), Apply(A), Combine (C). At ensemble level. @@ -964,6 +977,7 @@ def SAC_predict( data (pd.DataFrame, str): data verbosity (int, optional): Defaults to 0. n_jobs (int): number of processors for parallel computing + temporal_window_prequery (bool): Whether to prequery the temporal windows as pd.DataFrame object to speed-up the stixel query. If set to True, query speed will be faster but with a moderate memory usage increase. Returns: pd.DataFrame: prediction results. 
@@ -974,10 +988,10 @@ def SAC_predict( # Parallel maker if n_jobs == 1: - output_generator = (self.SAC_ensemble_predict(single_ensemble_df=ensemble[1], data=data) for ensemble in groups) + output_generator = (self.SAC_ensemble_predict(single_ensemble_df=ensemble[1], data=data, temporal_window_prequery=temporal_window_prequery) for ensemble in groups) else: def mp_predict(ensemble, self=self): - res = self.SAC_ensemble_predict(single_ensemble_df=ensemble[1], data=data) + res = self.SAC_ensemble_predict(single_ensemble_df=ensemble[1], data=data, temporal_window_prequery=temporal_window_prequery) return res parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator", backend=self.joblib_backend, temp_folder=self.joblib_tmp_dir) @@ -1013,6 +1027,7 @@ def predict_proba( return_by_separate_ensembles: bool = False, logit_agg: bool = False, base_model_method: Union[None, str] = None, + temporal_window_prequery: bool = False, **base_model_prediction_param ) -> Union[np.ndarray, Tuple[np.ndarray]]: """Predict probability @@ -1038,6 +1053,8 @@ def predict_proba( Whether to use logit aggregation for the classification task. Most likely only used when you are predicting "real" calibrated probability. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-tranforms it to probability scale. It's recommended to be jointly used with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability. Default is False, but can be set to true for "real" probability averaging. base_model_method: The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Notice that dummy model will still predict 0, so the ensemble-aggregated result is still an average of zeros and your special prediction function output. Therefore, it may only make sense if your special prediction function predicts 0 as the absense/control value. Defaults to None. + temporal_window_prequery: + Whether to prequery the temporal windows as pd.DataFrame object to speed-up the stixel query. If set to True, query speed will be faster but with a moderate memory usage increase. base_model_prediction_param: Any other paramters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1} (set n_jobs=1 for the *base model*). 
Raises: @@ -1054,7 +1071,7 @@ def predict_proba( """ check_X_test(X_test, self) - check_prediciton_aggregation(aggregation) + check_prediction_aggregation(aggregation) return_by_separate_ensembles, return_std = check_prediction_return(return_by_separate_ensembles, return_std) verbosity = check_verbosity(self, verbosity) n_jobs = check_transform_n_jobs(self, n_jobs) @@ -1067,7 +1084,7 @@ def predict_proba( try: # predict - res = self.SAC_predict(self.ensemble_df, X_test, verbosity=verbosity, n_jobs=n_jobs) + res = self.SAC_predict(self.ensemble_df, X_test, verbosity=verbosity, n_jobs=n_jobs, temporal_window_prequery=temporal_window_prequery) except: # Remove the entire lazy_loading_dir since it includes failed models raise finally: # Remove the joblib_tmp_dir anyway @@ -1159,6 +1176,7 @@ def predict( return_by_separate_ensembles: bool = False, logit_agg: bool = False, base_model_method: Union[None, str] = None, + temporal_window_prequery: bool = False, **base_model_prediction_param ) -> Union[np.ndarray, Tuple[np.ndarray]]: pass @@ -1381,7 +1399,7 @@ def assign_feature_importances_by_points( # verbosity = check_verbosity(self, verbosity=verbosity) n_jobs = check_transform_n_jobs(self, n_jobs) - check_prediciton_aggregation(aggregation) + check_prediction_aggregation(aggregation) # if "feature_importances_" not in dir(self): @@ -1476,32 +1494,42 @@ def load(tar_gz_file, new_lazy_loading_path=None, remove_original_file=False): if model.lazy_loading: for model_name in model.model_dict: if isinstance(model.model_dict[model_name], LazyLoadingEstimator): - model.model_dict[model_name].dump_dir = Path(os.path.join(new_lazy_loading_path, 'models', 'ensemble_' + model_name.split('_')[1])) + model.model_dict[model_name].dump_dir = Path(os.path.join(new_lazy_loading_path, 'models', 'ensemble_' + model_name.split('_')[0])) if remove_original_file: os.remove(tar_gz_file) return model - def save(self, tar_gz_file, remove_temporary_file = True): - if not os.path.exists(self.lazy_loading_dir): - os.makedirs(self.lazy_loading_dir, exist_ok=False) + def save(self, tar_gz_file, remove_temporary_file = True, verbosity=1, compresslevel=2): + os.makedirs(self.lazy_loading_dir, exist_ok=True) - # temporary save the model using pickle - model_path = os.path.join(self.lazy_loading_dir, f'model.pkl') + # dump the main object + model_path = os.path.join(self.lazy_loading_dir, 'model.pkl') with open(model_path, 'wb') as f: pickle.dump(self, f) - - # save the main model class and potentially lazyloading pieces to the tar.gz file - with tarfile.open(tar_gz_file, "w:gz") as tar: - for pieces in os.listdir(self.lazy_loading_dir): - tar.add(os.path.join(self.lazy_loading_dir, pieces), arcname=pieces) - if remove_temporary_file: - if self.lazy_loading_dir is not None: - if os.path.exists(self.lazy_loading_dir): - shutil.rmtree(self.lazy_loading_dir) + # collect files recursively (deterministic order) + root = os.path.abspath(self.lazy_loading_dir) + files = [] + for dp, dns, fns in os.walk(root): + dns.sort(); fns.sort() + for fn in fns: + files.append(os.path.join(dp, fn)) + + it = files + if verbosity > 0 and tqdm is not None: + it = tqdm(files, desc="Archiving", unit="file") + with tarfile.open(tar_gz_file, "w:gz", compresslevel=compresslevel) as tar: + for fp in it: + arcname = os.path.relpath(fp, root) + tar.add(fp, arcname=arcname, recursive=False) + + if remove_temporary_file and os.path.exists(self.lazy_loading_dir): + shutil.rmtree(self.lazy_loading_dir) + + @staticmethod def _cleanup(lazy_loading_dir): if 
lazy_loading_dir is not None: @@ -1619,6 +1647,7 @@ def predict( return_by_separate_ensembles: bool = False, logit_agg: bool = False, base_model_method: Union[None, str] = None, + temporal_window_prequery: bool = False, **base_model_prediction_param ) -> Union[np.ndarray, Tuple[np.ndarray]]: """A rewrite of predict_proba adapted for Classifier @@ -1646,6 +1675,8 @@ def predict( Whether to use logit aggregation for the classification task. If True, the model is averaging the probability prediction estimated by all ensembles in logit scale, and then back-tranform it to probability scale. It's recommened to be combinedly used with the CalibratedClassifierCV class in sklearn as a wrapper of the classifier to estimate the calibrated probability. base_model_method: The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Defaults to None. + temporal_window_prequery: + Whether to prequery the temporal windows as pd.DataFrame object to speed-up the stixel query. If set to True, query speed will be faster but with a moderate memory usage increase. base_model_prediction_param: Any other paramters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1}. Raises: @@ -1671,6 +1702,7 @@ def predict( return_by_separate_ensembles=return_by_separate_ensembles, logit_agg=logit_agg, base_model_method=base_model_method, + temporal_window_prequery=temporal_window_prequery, **base_model_prediction_param ) mean = mean[:,1].flatten() @@ -1688,6 +1720,7 @@ def predict( return_by_separate_ensembles=return_by_separate_ensembles, logit_agg=logit_agg, base_model_method=base_model_method, + temporal_window_prequery=temporal_window_prequery, **base_model_prediction_param ) mean = mean[:,1].flatten() @@ -1806,6 +1839,7 @@ def predict( aggregation: str = "mean", return_by_separate_ensembles: bool = False, base_model_method: Union[None, str] = None, + temporal_window_prequery: bool = False, **base_model_prediction_param ) -> Union[np.ndarray, Tuple[np.ndarray]]: """A rewrite of predict_proba @@ -1827,6 +1861,8 @@ def predict( Experimental function. return not by aggregation, but by separate ensembles. base_model_method: The name of the prediction method for base models. If None, `predict` or `predict_proba` will be used depending on the tasks. This argument is handy if you have a custom base model class that has a special prediction function. Defaults to None. + temporal_window_prequery: + Whether to prequery the temporal windows as pd.DataFrame object to speed-up the stixel query. If set to True, query speed will be faster but with a moderate memory usage increase. base_model_prediction_param: Any other paramters to pass into the prediction method of the base models. e.g., base_model_prediction_param={'n_jobs':1}. 
@@ -1844,7 +1880,7 @@ def predict( """ - prediciton = self.predict_proba( + prediction = self.predict_proba( X_test, verbosity=verbosity, return_std=return_std, @@ -1852,10 +1888,11 @@ def predict( aggregation=aggregation, return_by_separate_ensembles=return_by_separate_ensembles, base_model_method = base_model_method, + temporal_window_prequery=temporal_window_prequery, **base_model_prediction_param ) # if return_by_separate_ensembles, this will be the dataframe for ensemble # if return_std, this wil be a tuple of mean and std of prediction # if none of these, then it ill output the mean prediction - return prediciton + return prediction diff --git a/stemflow/model/SphereAdaSTEM.py b/stemflow/model/SphereAdaSTEM.py index 9e1f1be..3000a9b 100644 --- a/stemflow/model/SphereAdaSTEM.py +++ b/stemflow/model/SphereAdaSTEM.py @@ -21,7 +21,7 @@ from ..utils.sphere_quadtree import get_one_ensemble_sphere_quadtree from ..utils.validation import ( check_base_model, - check_prediciton_aggregation, + check_prediction_aggregation, check_random_state, check_spatial_scale, check_spatio_bin_jitter_magnitude, @@ -437,7 +437,8 @@ def find_belonged_points(df, st_indexes_df, X_df, y_df): return res_list def SAC_ensemble_predict( - self, single_ensemble_df: pd.DataFrame, data: Optional[pd.DataFrame] = None + self, single_ensemble_df: pd.DataFrame, data: Optional[pd.DataFrame] = None, + temporal_window_prequery: bool = False ) -> pd.DataFrame: """A sub-module of SAC prediction function. Predict only one ensemble. diff --git a/stemflow/utils/plot_gif.py b/stemflow/utils/plot_gif.py index 51e0a10..bd778f4 100644 --- a/stemflow/utils/plot_gif.py +++ b/stemflow/utils/plot_gif.py @@ -1,11 +1,11 @@ -from typing import Tuple, Union - +from typing import Tuple, Union, Optional import numpy as np import pandas as pd -import matplotlib import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation, PillowWriter from matplotlib.colors import Normalize +import cartopy.crs as ccrs +import cartopy.feature as cfeature def make_sample_gif( @@ -15,14 +15,15 @@ def make_sample_gif( Spatio1: str = "longitude", Spatio2: str = "latitude", Temporal1: str = "DOY", + continental_boundary: bool = True, figsize: Tuple[Union[float, int], Union[float, int]] = (18, 9), xlims: Tuple[Union[float, int], Union[float, int]] = None, ylims: Tuple[Union[float, int], Union[float, int]] = None, grid: bool = True, lng_size: int = 20, lat_size: int = 20, - xtick_interval: Union[float, int, None] = None, - ytick_interval: Union[float, int, None] = None, + xtick_interval: Union[float, int, None] = None, # used only when continental_boundary=False + ytick_interval: Union[float, int, None] = None, # used only when continental_boundary=False log_scale: bool = False, vmin: Union[float, int] = 0.0001, vmax: Union[float, int, None] = None, @@ -32,186 +33,252 @@ def make_sample_gif( fps: int = 30, cmap: str = "plasma", verbose: int = 1, + political_boundary: Optional[str] = None, # None | "country" | "province" | "both" + boundary_scale: str = "110m", # "110m" | "50m" | "10m" + boundary_color: str = "black", + boundary_lw: float = 0.4, + boundary_alpha: float = 0.7, + boundary_zorder: int = 2, + show_major_lakes: bool = True ): """ - Create a GIF visualizing spatio-temporal data using plt.imshow. + Create a GIF visualizing spatio-temporal data using imshow, with optional + Cartopy physical (continental) and political boundaries. Args: - data (pd.DataFrame): Input DataFrame, pre-filtered for the target area/time. 
+ data (pd.DataFrame): Input DataFrame containing spatio-temporal data. file_path (str): Output GIF file path. - col (str): Column containing the values to plot. - Spatio1 (str): First spatial variable column. - Spatio2 (str): Second spatial variable column. - Temporal1 (str): Temporal variable column. - figsize (Tuple[Union[float, int], Union[float, int]]): Figure size. - xlims (Tuple[Union[float, int], Union[float, int]]): x-axis limits. - ylims (Tuple[Union[float, int], Union[float, int]]): y-axis limits. - grid (bool): Whether to display a grid. - lng_size (int): Number of longitudinal pixels (resolution). - lat_size (int): Number of latitudinal pixels (resolution). - xtick_interval (Union[float, int, None]): Interval between x-ticks. - ytick_interval (Union[float, int, None]): Interval between y-ticks. - log_scale (bool): Whether to apply a logarithmic scale to the data. - vmin (Union[float, int]): Minimum value for color scaling. - vmax (Union[float, int, None]): Maximum value for color scaling. + col (str): Column name containing the values to visualize (e.g., abundance). + Spatio1 (str): Column name for the first spatial variable (e.g., longitude). + Spatio2 (str): Column name for the second spatial variable (e.g., latitude). + Temporal1 (str): Column name for the temporal variable (e.g., DOY). + continental_boundary (bool): Whether to display physical continental outlines using Cartopy. + figsize (Tuple[Union[float, int], Union[float, int]]): Figure size in inches. + xlims (Tuple[Union[float, int], Union[float, int]]): Longitude limits (min, max). + ylims (Tuple[Union[float, int], Union[float, int]]): Latitude limits (min, max). + grid (bool): Whether to draw gridlines on the plot. + lng_size (int): Number of longitudinal grid cells (spatial resolution). + lat_size (int): Number of latitudinal grid cells (spatial resolution). + xtick_interval (Union[float, int, None]): Custom x-axis tick interval (used only if continental_boundary=False). + ytick_interval (Union[float, int, None]): Custom y-axis tick interval (used only if continental_boundary=False). + log_scale (bool): Apply logarithmic scaling to the plotted values. + vmin (Union[float, int]): Minimum value for the colormap normalization. + vmax (Union[float, int, None]): Maximum value for the colormap normalization (auto-detected if None). lightgrey_under (bool): Use light grey color for values below vmin. - adder (Union[int, float]): Value to add before log transformation. - dpi (Union[float, int]): Dots per inch for the output GIF. - fps (int): Frames per second for the GIF. - cmap (str): Colormap to use. - verbose (int): Verbosity level. + adder (Union[int, float]): Value added before log transformation to avoid log(0). + dpi (Union[float, int]): Output resolution (dots per inch). + fps (int): Frames per second for the GIF animation. + cmap (str): Matplotlib colormap name. + verbose (int): Verbosity level; 0 = silent, 1 = print progress. + political_boundary (Optional[str]): Type of political boundaries to overlay. + Options: + - None: No political boundaries + - "country": Show country borders (admin-0) + - "province": Show state/province boundaries (admin-1) + - "both": Show both country and province boundaries + boundary_scale (str): Scale for boundary data ("110m", "50m", or "10m"). + boundary_color (str): Color of physical and political boundaries. + boundary_lw (float): Line width of boundaries. + boundary_alpha (float): Transparency (alpha) of boundary lines. 
+ boundary_zorder (int): Z-order (drawing order) of boundary layers. + show_major_lakes (bool): Whether to draw outlines of major lakes. + + Returns: + None. Saves the generated GIF to the specified file path. + + Notes: + - Spatial binning is performed using `np.digitize` to create a gridded raster. + - Each frame corresponds to a unique value in the temporal column. + - The function supports both linear and log-scaled color mapping. + - Requires `cartopy` for geographic projections and natural features. + + Example: + >>> make_sample_gif( + ... data=df, + ... file_path="output.gif", + ... col="abundance", + ... Spatio1="longitude", + ... Spatio2="latitude", + ... Temporal1="DOY", + ... political_boundary="both", + ... log_scale=True + ... ) """ - # Sort data by the temporal variable + + # Sort data by the temporal variable & make frame index data = data.sort_values(by=Temporal1) data["Temporal_indexer"], _ = pd.factorize(data[Temporal1]) + frames = data["Temporal_indexer"].nunique() - # Set x and y limits if not provided + # Spatial bounds if xlims is None: - xlims = (data[Spatio1].min(), data[Spatio1].max()) + xlims = (float(data[Spatio1].min()), float(data[Spatio1].max())) if ylims is None: - ylims = (data[Spatio2].min(), data[Spatio2].max()) + ylims = (float(data[Spatio2].min()), float(data[Spatio2].max())) - # Create spatial grids without slicing + # Binning grids (lat reversed so row 0 = max lat; therefore origin='upper') lng_grid = np.linspace(xlims[0], xlims[1], lng_size + 1) lat_grid = np.linspace(ylims[0], ylims[1], lat_size + 1)[::-1] - # Determine tick intervals - closest_set = ( - [10.0 ** i for i in np.arange(-15, 15, 1)] - + [10.0 ** i / 2 for i in np.arange(-15, 15, 1)] - ) - spatio1_base = (xlims[1] - xlims[0]) / 5 - if xtick_interval is None: - xtick_interval = min( - closest_set, - key=lambda x: np.inf if x - spatio1_base > 0 else abs(x - spatio1_base), - ) - if xtick_interval >= 1: - xtick_interval = int(xtick_interval) - - spatio2_base = (ylims[1] - ylims[0]) / 5 - if ytick_interval is None: - ytick_interval = min( - closest_set, - key=lambda x: np.inf if x - spatio2_base > 0 else abs(x - spatio2_base), - ) - if ytick_interval >= 1: - ytick_interval = int(ytick_interval) - - # Utility function to round numbers to the same decimal places - def round_to_same_decimal_places(A, B): - str_B = str(B) - if "." 
in str_B:
-            decimal_places = len(str_B.split(".")[1])
-        else:
-            decimal_places = 0
-        rounded_A = round(A, decimal_places)
-        if abs(rounded_A) > 1000:
-            formatted_A = format(rounded_A, f".{decimal_places}e")
-        else:
-            formatted_A = f"{rounded_A:.{decimal_places}f}"
-        return formatted_A
-
-    # Initialize figure and axes
-    fig, ax = plt.subplots(figsize=figsize)
-
-    # Set color scaling
+    # Color scaling
     if vmax is None:
-        vmax = (
-            np.max(np.log(data[col].values + adder))
-            if log_scale
-            else np.max(data[col].values)
-        )
-
+        vmax = (np.nanmax(np.log(data[col].values + adder))
+                if log_scale else np.nanmax(data[col].values))
     norm = Normalize(vmin=vmin, vmax=vmax)
-    # Prepare colormap
+    # Colormap
     my_cmap = plt.get_cmap(cmap)
     if lightgrey_under:
+        try:
+            my_cmap = my_cmap.copy()
+        except Exception:
+            pass
         my_cmap.set_under("lightgrey")
-    # Initialize the image to set up the colorbar
-    im = ax.imshow(
-        np.zeros((lat_size, lng_size)), norm=norm, cmap=my_cmap, animated=True
-    )
-    cbar = fig.colorbar(im, ax=ax, shrink=0.5)
-    cbar.ax.get_yaxis().labelpad = 15
-    cbar_label = f"log({col})" if log_scale else col
-    cbar.ax.set_ylabel(cbar_label, rotation=270)
-
-    # Precompute tick labels and positions
-    x_ticks = np.arange(xlims[0], xlims[1] + xtick_interval, xtick_interval)
-    x_tick_labels = [round_to_same_decimal_places(val, xtick_interval) for val in x_ticks]
-    # Find positions of x_ticks within lng_grid
-    x_tick_positions = np.searchsorted(lng_grid, x_ticks, side='left')
-    ax.set_xticks(x_tick_positions)
-    ax.set_xticklabels(x_tick_labels)
-
-    y_ticks = np.arange(ylims[0], ylims[1] + ytick_interval, ytick_interval)
-    y_tick_labels = [round_to_same_decimal_places(val, ytick_interval) for val in y_ticks]
-    # Since lat_grid is reversed, we need to account for that
-    y_tick_positions = lat_size - np.searchsorted(lat_grid[::-1], y_ticks, side='left') - 1
-    ax.set_yticks(y_tick_positions)
-    ax.set_yticklabels(y_tick_labels)
-
-    # Animation function
-    def animate(i):
-        if verbose >= 1:
-            print(f"Processing frame {i+1}/{frames}", end="\r")
+    # Figure & axes
+    if continental_boundary:
+        fig, ax = plt.subplots(
+            figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}
+        )
+        ax.set_extent([xlims[0], xlims[1], ylims[0], ylims[1]], crs=ccrs.PlateCarree())
+    else:
+        if political_boundary:
+            raise ValueError(
+                "political_boundary requires continental_boundary=True."
+            )
+        if show_major_lakes:
+            raise ValueError(
+                "show_major_lakes requires continental_boundary=True."
+ ) + fig, ax = plt.subplots(figsize=figsize) + ax.set_xlim(xlims) + ax.set_ylim(ylims) + # Optional custom ticks only for non-Cartopy axes + if xtick_interval is not None: + ax.set_xticks(np.arange(xlims[0], xlims[1] + xtick_interval, xtick_interval)) + if ytick_interval is not None: + ax.set_yticks(np.arange(ylims[0], ylims[1] + ytick_interval, ytick_interval)) + if grid: + ax.grid(alpha=0.5) - ax.clear() - sub = data[data["Temporal_indexer"] == i].copy() - if sub.empty: - return [] + # ONE persistent image artist (updated per frame) + if continental_boundary: + im = ax.imshow( + np.full((lat_size, lng_size), np.nan), + norm=norm, + cmap=my_cmap, + extent=[xlims[0], xlims[1], ylims[0], ylims[1]], + origin="upper", + transform=ccrs.PlateCarree(), + animated=False, + zorder=1, + interpolation="nearest", + resample=False + ) - temporal_value = sub[Temporal1].iloc[0] + # Physical outlines + ax.coastlines(resolution=boundary_scale, linewidth=boundary_lw*2, zorder=boundary_zorder) + ax.add_feature( + cfeature.LAND.with_scale(boundary_scale), + facecolor="none", + edgecolor=boundary_color, + linewidth=boundary_lw, + zorder=boundary_zorder, + ) - # Correct digitization with adjusted bins - g1 = np.digitize(sub[Spatio1], lng_grid, right=False) - 1 - g1 = np.clip(g1, 0, lng_size - 1).astype(int) + if show_major_lakes: + ax.add_feature( + cfeature.LAKES.with_scale(boundary_scale), # or "110m" for coarser + facecolor="none", + edgecolor=boundary_color, + alpha=boundary_alpha, + zorder=boundary_zorder, + ) - g2 = np.digitize(sub[Spatio2], lat_grid, right=False) - 1 - g2 = np.clip(g2, 0, lat_size - 1).astype(int) + # --- NEW: Political boundaries (admin-0 and/or admin-1) --- + if political_boundary in {"country", "both"}: + ax.add_feature( + cfeature.BORDERS.with_scale(boundary_scale), + edgecolor=boundary_color, + linewidth=boundary_lw, + alpha=boundary_alpha, + zorder=boundary_zorder, + ) - sub[f"{Spatio1}_grid"] = g1 - sub[f"{Spatio2}_grid"] = g2 + if political_boundary in {"province", "both"}: + provinces = cfeature.NaturalEarthFeature( + category="cultural", + name="admin_1_states_provinces_lines", + scale=boundary_scale, + facecolor="none", + ) + ax.add_feature( + provinces, + edgecolor=boundary_color, + linewidth=boundary_lw * 0.9, + alpha=boundary_alpha, + zorder=boundary_zorder, + ) - grouped = sub.groupby( - [f"{Spatio2}_grid", f"{Spatio1}_grid"] - )[col].mean() + if grid: + gl = ax.gridlines(draw_labels=True, linewidth=0.3, alpha=0.5) + gl.top_labels = False + gl.right_labels = False - im_data = np.full((lat_size, lng_size), np.nan) - indices = (grouped.index.get_level_values(0), grouped.index.get_level_values(1)) - values = np.log(grouped.values + adder) if log_scale else grouped.values - im_data[indices] = values + else: + im = ax.imshow( + np.full((lat_size, lng_size), np.nan), + norm=norm, + cmap=my_cmap, + extent=None, # pixel coords + origin="upper", + animated=False, + zorder=1, + ) + + # Colorbar & title + cbar = fig.colorbar(im, ax=ax, shrink=0.5) + cbar.ax.set_ylabel(f"log({col})" if log_scale else col, rotation=270, labelpad=15) + title_txt = ax.set_title("", fontsize=30) + + # Animation update: ONLY update raster + title (no ax.clear()) + def animate(i): + if verbose >= 1: + print(f"Processing frame {i+1}/{frames}", end="\r") - im = ax.imshow(im_data, norm=norm, cmap=my_cmap, animated=True) - ax.set_title(f"{Temporal1}: {temporal_value}", fontsize=30) + sub = data[data["Temporal_indexer"] == i] + im_data = np.full((lat_size, lng_size), np.nan) - # Re-apply ticks and grid in each 
frame - ax.set_xticks(x_tick_positions) - ax.set_xticklabels(x_tick_labels) - ax.set_yticks(y_tick_positions) - ax.set_yticklabels(y_tick_labels) + if not sub.empty: + # Bin to grid + g1 = np.digitize(sub[Spatio1].to_numpy(), lng_grid, right=False) - 1 + g1 = np.clip(g1, 0, lng_size - 1).astype(int) + g2 = np.digitize(sub[Spatio2].to_numpy(), lat_grid, right=False) - 1 + g2 = np.clip(g2, 0, lat_size - 1).astype(int) - if grid: - ax.grid(alpha=0.5) + grouped = ( + sub.assign(**{f"{Spatio1}_grid": g1, f"{Spatio2}_grid": g2}) + .groupby([f"{Spatio2}_grid", f"{Spatio1}_grid"])[col] + .mean() + ) - return [im] + if len(grouped) > 0: + idx0 = grouped.index.get_level_values(0) + idx1 = grouped.index.get_level_values(1) + vals = np.log(grouped.values + adder) if log_scale else grouped.values + im_data[idx0, idx1] = vals - frames = data["Temporal_indexer"].nunique() + temporal_value = sub[Temporal1].iloc[0] + title_txt.set_text(f"{Temporal1}: {temporal_value}") + else: + title_txt.set_text("") - # Create animation - ani = FuncAnimation( - fig, - animate, - frames=frames, - interval=1000 / fps, - blit=True, - repeat=True, - ) + im.set_data(im_data) + return (im,) + ani = FuncAnimation(fig, animate, frames=frames, interval=int(1000 / fps), blit=False, repeat=True) ani.save(file_path, dpi=dpi, writer=PillowWriter(fps=fps)) - plt.close() + plt.close(fig) if verbose >= 1: - print("\nAnimation saved successfully!") \ No newline at end of file + print("\nAnimation saved successfully!") + diff --git a/stemflow/utils/validation.py b/stemflow/utils/validation.py index 9e554db..5a6df24 100644 --- a/stemflow/utils/validation.py +++ b/stemflow/utils/validation.py @@ -377,7 +377,7 @@ def check_X_test(X_test, self): check_X_train(X_test, self) -def check_prediciton_aggregation(aggregation): +def check_prediction_aggregation(aggregation): if aggregation not in ["mean", "median"]: raise ValueError(f"aggregation must be one of 'mean' and 'median'. Got {aggregation}") diff --git a/stemflow/version.py b/stemflow/version.py index 9b102be..1436d8f 100644 --- a/stemflow/version.py +++ b/stemflow/version.py @@ -1 +1 @@ -__version__ = "1.1.5" +__version__ = "1.1.6" diff --git a/tests/test_makegif.py b/tests/test_makegif.py index 7d7e61c..d3d4d5b 100644 --- a/tests/test_makegif.py +++ b/tests/test_makegif.py @@ -71,3 +71,34 @@ def test_make_gif_changing_ranges(): ) shutil.rmtree(tmp_dir) + + +def test_make_gif_plot_political_boundary(): + + tmp_dir = "./stemflow_test_make_gif3" + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + + make_sample_gif( + fake_data, + os.path.join(tmp_dir, "FTR_IPT_dat_changing_ranges.gif"), + col="dat", + log_scale=False, + Spatio1="x", + Spatio2="y", + Temporal1="DOY", + figsize=(18, 9), + xlims=None, + ylims=None, + grid=True, + xtick_interval=None, + ytick_interval=None, + lng_size=10, + lat_size=10, + dpi=100, + fps=10, + political_boundary='both' + ) + + shutil.rmtree(tmp_dir) +
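Reviewer note: below is a minimal, hypothetical usage sketch of the LazyLoadingEstimator wrapper touched by this patch. It uses only the constructor arguments and the dump()/load() methods visible in the diff above; the import path is assumed from the file location (stemflow/lazyloading/lazyloading.py), and the exact behavior of fit/predict delegation through __getattr__ is an assumption based on the NEWS notes, not confirmed API.

# Hypothetical sketch only -- import path and fit/predict delegation are assumptions.
import numpy as np
from sklearn.linear_model import LinearRegression
from stemflow.lazyloading.lazyloading import LazyLoadingEstimator  # assumed import path

X = np.random.rand(100, 3)
y = X @ np.array([1.0, 2.0, 3.0])

est = LazyLoadingEstimator(
    estimator=LinearRegression(),
    dump_dir="./lazyloading_demo",   # where the pickled estimator is written
    filename="model_demo.pkl",
    compress=0,        # new default in this diff (no compression when pickling)
    auto_load=True,    # transparently reload from disk when an attribute is accessed
    auto_dump=True,    # dump (and free) the in-memory estimator once it has been used
    keep_loaded=False,
)

est.fit(X, y)       # delegated to the wrapped LinearRegression
print(est.coef_)    # attribute access auto-loads the dumped model (auto-dump after access is the 1.1.6 fix)
path = est.dump()   # explicit dump; returns the file path and frees memory
est.load(path)      # explicit in-place reload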
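Similarly, a hedged sketch of the two workflow changes described in NEWS.md and in the new temporal_window_prequery docstrings: passing a parquet/duckdb file path instead of a pandas DataFrame, and pre-querying temporal windows at prediction time. The constructor arguments (base_model, lazy_loading, n_jobs) and the file names are illustrative assumptions; see the new Jupyter notebooks referenced above for the confirmed workflow.

# Hypothetical sketch -- constructor arguments and file paths are illustrative only.
from xgboost import XGBRegressor
from stemflow.model.AdaSTEM import AdaSTEMRegressor  # existing class; its prediction API is extended by this diff

model = AdaSTEMRegressor(
    base_model=XGBRegressor(n_jobs=1),  # parallelism happens at the ensemble level
    lazy_loading=True,                  # dump each stixel model to disk after training (#77)
    n_jobs=4,                           # must not exceed the number of ensemble folds in this release
)

# Per the Union[pd.DataFrame, str] type hints in this diff, X/y may be file-path strings,
# so each joblib worker queries the data itself instead of receiving a pickled DataFrame copy.
model.fit("./X_train.parquet", "./y_train.parquet")

# temporal_window_prequery=True pre-queries each temporal window as a DataFrame:
# faster stixel queries at the cost of a moderate memory increase.
prediction = model.predict("./X_test.parquet", temporal_window_prequery=True)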