Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions datastock/_class0.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def to_dict(
)

@classmethod
def from_dict(cls, din=None, isflat=None, sep=None):
def from_dict(cls, din=None, isflat=None, sep=None, obj=None):
""" Populate the instances attributes using an input dict

The input dict must be properly formatted
Expand All @@ -85,7 +85,9 @@ def from_dict(cls, din=None, isflat=None, sep=None):
# ---------------------
# Instanciate and populate

obj = cls()
if obj is None:
obj = cls()

for k0 in din.keys():
if k0 == '_ddef':
if 'dobj' not in din[k0]['params'].keys():
Expand Down
4 changes: 2 additions & 2 deletions datastock/_class1.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ def remove_param(
###########

@classmethod
def from_dict(cls, din=None, sep=None):
obj = super().from_dict(din=din, sep=sep)
def from_dict(cls, din=None, sep=None, obj=None):
obj = super().from_dict(din=din, sep=sep, obj=obj)
obj.update()
return obj

Expand Down
33 changes: 27 additions & 6 deletions datastock/_saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,32 @@ def save(
def load(
pfe=None,
cls=None,
coll=None,
allow_pickle=None,
sep=None,
verb=None,
):

# -------------
# check inputs
# -------------

# ---------
# pfe

if not os.path.isfile(pfe):
msg = f"Arg pfe must be a valid path to a file!\n\t- Provided: {pfe}"
raise Exception(msg)

# cls
if cls is None:
from ._class import DataStock
cls = DataStock
# --------------
# cls vs coll

if coll is None:
if cls is None:
from ._class import DataStock
cls = DataStock
else:
cls = coll.__class__

if not (type(cls) is type and hasattr(cls, 'from_dict')):
msg = (
Expand All @@ -224,14 +233,18 @@ def load(
)
raise Exception(msg)

# ------------
# allow_pickle

allow_pickle = _generic_check._check_var(
allow_pickle, 'allow_pickle',
default=True,
types=bool,
)

# -------
# verb

verb = _generic_check._check_var(
verb, 'verb',
default=True,
Expand All @@ -240,11 +253,13 @@ def load(

# --------------
# load flat dict
# --------------

dflat = dict(np.load(pfe, allow_pickle=allow_pickle))

# ------------------------------
# load sep from file if exists
# ------------------------------

if _KEY_SEP in dflat.keys():
# new
Expand All @@ -256,6 +271,7 @@ def load(

# ----------
# reshape
# ----------

dout = {}
for k0, v0 in dflat.items():
Expand Down Expand Up @@ -310,14 +326,19 @@ def load(

# -----------
# Instanciate
# -----------

obj = cls.from_dict(dout)
coll = cls.from_dict(dout, obj=coll)

# -----------
# verb
# -----------

if verb:
msg = f"Loaded from\n\t{pfe}"
print(msg)

return obj
return coll


# #################################################################
Expand Down
10 changes: 10 additions & 0 deletions datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,16 @@ def test24_saveload(self, verb=False):
st2 = load(pfe, verb=verb)
# Just to check the loaded version works fine
msg = st2.__eq__(self.st, returnas=str)
if msg is not True:
raise Exception(msg)
os.remove(pfe)

def test25_saveload_coll(self, verb=False):
pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True)
st = DataStock()
st2 = load(pfe, coll=st, verb=verb)
# Just to check the loaded version works fine
msg = st2.__eq__(self.st, returnas=str)
if msg is not True:
raise Exception(msg)
os.remove(pfe)
Loading