diff --git a/datastock/_class0.py b/datastock/_class0.py index 8cfc15e..26b909d 100644 --- a/datastock/_class0.py +++ b/datastock/_class0.py @@ -161,11 +161,13 @@ def __hash__(self, *args, **kargs): def save( self, + pfe=None, path=None, name=None, sep=None, - verb=True, + overwrite=None, return_pfe=False, + verb=True, ): lsep = [';', '&', '?', '#', ',', '~', '.', '-', '_'] @@ -191,9 +193,11 @@ def save( asarray=True, returnas='blended', ), + pfe=pfe, sep=sep, path=path, name=name, + overwrite=overwrite, clsname=self.__class__.__name__, return_pfe=return_pfe, verb=verb, diff --git a/datastock/_saveload.py b/datastock/_saveload.py index 187991f..1524c42 100644 --- a/datastock/_saveload.py +++ b/datastock/_saveload.py @@ -26,10 +26,12 @@ def save( dflat=None, + pfe=None, sep=None, name=None, path=None, clsname=None, + overwrite=None, return_pfe=None, verb=None, ): @@ -37,30 +39,97 @@ def save( # ------------ # check inputs + # ------------ - # path - path = _generic_check._check_var( - path, 'path', - default=os.path.abspath('./'), - types=str, - ) - path = os.path.abspath(path) - if not os.path.isdir(path): - msg = f"Arg path must be a valid path!\nProvided: {path}" + # ------------------ + # pfe vs path/name + + lc = [ + pfe is not None, + path is not None or name is not None, + ] + + if np.sum(lc) > 1: + msg = ( + "Saving, please provide {pfe} xor {path and/or name}!\n" + f"\t- path: {path}\n" + f"\t- name: {name}\n" + f"\t- pfe: {pfe}\n" + ) raise Exception(msg) - # clsname - clsname = _generic_check._check_var( - clsname, 'clsname', - default='DataCollection', - types=str, - ) + # ------------------ + # pfe vs path/name + + if lc[0]: + if not isinstance(pfe, str): + msg = ( + "Arg pfe must be a str ponting to a file!\n" + f"Provided: {pfe}\n" + ) + raise Exception(msg) - # name - name = _generic_check._check_var( - name, 'name', - default='name', - types=str, + sdir, sfile = os.path.split(pfe) + if sdir == '': + sdir = os.path.asbpath('.') + + # check path + if not os.path.isdir(sdir): + msg = ( + "Arg pfe seems to have a non-valid path!\n" + "Provided: {sdir}\n" + ) + raise Exception(msg) + + # check file name + if not sfile.endswith('.npz'): + sfile = f"{sfile}.npz" + + # re-assemble + pfe = os.path.join(sdir, sfile) + + else: + # path + path = _generic_check._check_var( + path, 'path', + default=os.path.abspath('./'), + types=str, + ) + path = os.path.abspath(path) + if not os.path.isdir(path): + msg = f"Arg path must be a valid path!\nProvided: {path}" + raise Exception(msg) + + # clsname + clsname = _generic_check._check_var( + clsname, 'clsname', + default='DataCollection', + types=str, + ) + + # name + name = _generic_check._check_var( + name, 'name', + default='name', + types=str, + ) + + # set automatic name + user = getpass.getuser() + dt = dtm.datetime.now().strftime("%Y%m%d-%H%M%S") + name = f'{clsname}_{name}_{user}_{dt}.npz' + + # pfe + pfe = os.path.join(path, name) + + # ------------------ + # options + + # overwrite + overwrite = _generic_check._check_var( + overwrite, 'overwrite', + default=False, + types=bool, ) # verb @@ -79,24 +148,44 @@ def save( # ---------------------- # save / print / return - - user = getpass.getuser() - dt = dtm.datetime.now().strftime("%Y%m%d-%H%M%S") - name = f'{clsname}_{name}_{user}_{dt}.npz' + # ---------------------- # add sep dflat[_KEY_SEP] = sep + # ----------------- + # check vs existing + + if os.path.isfile(pfe): + if overwrite is True: + msg = ( + "Overwriting existing file:\n" + f"\t{pfe}" + ) + warnings.warn(msg) + else: + msg = ( + "File already existing!\n" + "\t=> use overwrite = True to overwrite\n" + f"\t{pfe}" + ) + raise Exception(msg) + + # -------- # save - pfe = os.path.join(path, name) - np.savez(pfe, **dflat) + np.savez(pfe, **dflat) + + # ------- # print + if verb: - msg = f"Saved in:\n\t{pfe}" + msg = f"\nSaved in:\n\t{pfe}\n" print(msg) + # ------- # return + if return_pfe is True: return pfe diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 0ebd01e..2b2213d 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -602,7 +602,12 @@ def test21_copy_equal(self): def test22_get_nbytes(self): nb, dnb = self.st.get_nbytes() - def test23_saveload(self, verb=False): + def test23_save_pfe(self, verb=False): + pfe = os.path.join(_PATH_OUTPUT, 'testsave.npz') + self.st.save(pfe=pfe, return_pfe=False) + os.remove(pfe) + + def test24_saveload(self, verb=False): pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True) st2 = load(pfe, verb=verb) # Just to check the loaded version works fine diff --git a/datastock/version.py b/datastock/version.py index 0b42797..ef44963 100644 --- a/datastock/version.py +++ b/datastock/version.py @@ -1,2 +1,2 @@ # Do not edit, pipeline versioning governed by git tags! -__version__ = '0.0.42' +__version__ = '0.0.43'