Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion datastock/_class0.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ def __hash__(self, *args, **kargs):

def save(
self,
pfe=None,
path=None,
name=None,
sep=None,
verb=True,
overwrite=None,
return_pfe=False,
verb=True,
):

lsep = [';', '&', '?', '#', ',', '~', '.', '-', '_']
Expand All @@ -191,9 +193,11 @@ def save(
asarray=True,
returnas='blended',
),
pfe=pfe,
sep=sep,
path=path,
name=name,
overwrite=overwrite,
clsname=self.__class__.__name__,
return_pfe=return_pfe,
verb=verb,
Expand Down
143 changes: 116 additions & 27 deletions datastock/_saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,110 @@

def save(
dflat=None,
pfe=None,
sep=None,
name=None,
path=None,
clsname=None,
overwrite=None,
return_pfe=None,
verb=None,
):
""" Save flattened dict """

# ------------
# check inputs
# ------------

# path
path = _generic_check._check_var(
path, 'path',
default=os.path.abspath('./'),
types=str,
)
path = os.path.abspath(path)
if not os.path.isdir(path):
msg = f"Arg path must be a valid path!\nProvided: {path}"
# ------------------
# pfe vs path/name

lc = [
pfe is not None,
path is not None or name is not None,
]

if np.sum(lc) > 1:
msg = (
"Saving, please provide {pfe} xor {path and/or name}!\n"
f"\t- path: {path}\n"
f"\t- name: {name}\n"
f"\t- pfe: {pfe}\n"
)
raise Exception(msg)

# clsname
clsname = _generic_check._check_var(
clsname, 'clsname',
default='DataCollection',
types=str,
)
# ------------------
# pfe vs path/name

if lc[0]:
if not isinstance(pfe, str):
msg = (
"Arg pfe must be a str ponting to a file!\n"
f"Provided: {pfe}\n"
)
raise Exception(msg)

# name
name = _generic_check._check_var(
name, 'name',
default='name',
types=str,
sdir, sfile = os.path.split(pfe)
if sdir == '':
sdir = os.path.asbpath('.')

# check path
if not os.path.isdir(sdir):
msg = (
"Arg pfe seems to have a non-valid path!\n"
"Provided: {sdir}\n"
)
raise Exception(msg)

# check file name
if not sfile.endswith('.npz'):
sfile = f"{sfile}.npz"

# re-assemble
pfe = os.path.join(sdir, sfile)

else:
# path
path = _generic_check._check_var(
path, 'path',
default=os.path.abspath('./'),
types=str,
)
path = os.path.abspath(path)
if not os.path.isdir(path):
msg = f"Arg path must be a valid path!\nProvided: {path}"
raise Exception(msg)

# clsname
clsname = _generic_check._check_var(
clsname, 'clsname',
default='DataCollection',
types=str,
)

# name
name = _generic_check._check_var(
name, 'name',
default='name',
types=str,
)

# set automatic name
user = getpass.getuser()
dt = dtm.datetime.now().strftime("%Y%m%d-%H%M%S")
name = f'{clsname}_{name}_{user}_{dt}.npz'

# pfe
pfe = os.path.join(path, name)

# ------------------
# options

# overwrite
overwrite = _generic_check._check_var(
overwrite, 'overwrite',
default=False,
types=bool,
)

# verb
Expand All @@ -79,24 +148,44 @@ def save(

# ----------------------
# save / print / return

user = getpass.getuser()
dt = dtm.datetime.now().strftime("%Y%m%d-%H%M%S")
name = f'{clsname}_{name}_{user}_{dt}.npz'
# ----------------------

# add sep
dflat[_KEY_SEP] = sep

# -----------------
# check vs existing

if os.path.isfile(pfe):
if overwrite is True:
msg = (
"Overwriting existing file:\n"
f"\t{pfe}"
)
warnings.warn(msg)
else:
msg = (
"File already existing!\n"
"\t=> use overwrite = True to overwrite\n"
f"\t{pfe}"
)
raise Exception(msg)

# --------
# save
pfe = os.path.join(path, name)
np.savez(pfe, **dflat)

np.savez(pfe, **dflat)

# -------
# print

if verb:
msg = f"Saved in:\n\t{pfe}"
msg = f"\nSaved in:\n\t{pfe}\n"
print(msg)

# -------
# return

if return_pfe is True:
return pfe

Expand Down
7 changes: 6 additions & 1 deletion datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,12 @@ def test21_copy_equal(self):
def test22_get_nbytes(self):
nb, dnb = self.st.get_nbytes()

def test23_saveload(self, verb=False):
def test23_save_pfe(self, verb=False):
pfe = os.path.join(_PATH_OUTPUT, 'testsave.npz')
self.st.save(pfe=pfe, return_pfe=False)
os.remove(pfe)

def test24_saveload(self, verb=False):
pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True)
st2 = load(pfe, verb=verb)
# Just to check the loaded version works fine
Expand Down
2 changes: 1 addition & 1 deletion datastock/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Do not edit, pipeline versioning governed by git tags!
__version__ = '0.0.42'
__version__ = '0.0.43'
Loading