Skip to content

Commit 070341b

Browse files
authored
Merge pull request #185 from ToFuProject/Issue184_loadColl
[#184] Done
2 parents 03b6ad4 + e266fda commit 070341b

File tree

4 files changed

+43
-10
lines changed

4 files changed

+43
-10
lines changed

datastock/_class0.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def to_dict(
6565
)
6666

6767
@classmethod
68-
def from_dict(cls, din=None, isflat=None, sep=None):
68+
def from_dict(cls, din=None, isflat=None, sep=None, obj=None):
6969
""" Populate the instances attributes using an input dict
7070
7171
The input dict must be properly formatted
@@ -85,7 +85,9 @@ def from_dict(cls, din=None, isflat=None, sep=None):
8585
# ---------------------
8686
# Instanciate and populate
8787

88-
obj = cls()
88+
if obj is None:
89+
obj = cls()
90+
8991
for k0 in din.keys():
9092
if k0 == '_ddef':
9193
if 'dobj' not in din[k0]['params'].keys():

datastock/_class1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ def remove_param(
369369
###########
370370

371371
@classmethod
372-
def from_dict(cls, din=None, sep=None):
373-
obj = super().from_dict(din=din, sep=sep)
372+
def from_dict(cls, din=None, sep=None, obj=None):
373+
obj = super().from_dict(din=din, sep=sep, obj=obj)
374374
obj.update()
375375
return obj
376376

datastock/_saveload.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,23 +199,32 @@ def save(
199199
def load(
200200
pfe=None,
201201
cls=None,
202+
coll=None,
202203
allow_pickle=None,
203204
sep=None,
204205
verb=None,
205206
):
206207

207208
# -------------
208209
# check inputs
210+
# -------------
209211

212+
# ---------
210213
# pfe
214+
211215
if not os.path.isfile(pfe):
212216
msg = f"Arg pfe must be a valid path to a file!\n\t- Provided: {pfe}"
213217
raise Exception(msg)
214218

215-
# cls
216-
if cls is None:
217-
from ._class import DataStock
218-
cls = DataStock
219+
# --------------
220+
# cls vs coll
221+
222+
if coll is None:
223+
if cls is None:
224+
from ._class import DataStock
225+
cls = DataStock
226+
else:
227+
cls = coll.__class__
219228

220229
if not (type(cls) is type and hasattr(cls, 'from_dict')):
221230
msg = (
@@ -224,14 +233,18 @@ def load(
224233
)
225234
raise Exception(msg)
226235

236+
# ------------
227237
# allow_pickle
238+
228239
allow_pickle = _generic_check._check_var(
229240
allow_pickle, 'allow_pickle',
230241
default=True,
231242
types=bool,
232243
)
233244

245+
# -------
234246
# verb
247+
235248
verb = _generic_check._check_var(
236249
verb, 'verb',
237250
default=True,
@@ -240,11 +253,13 @@ def load(
240253

241254
# --------------
242255
# load flat dict
256+
# --------------
243257

244258
dflat = dict(np.load(pfe, allow_pickle=allow_pickle))
245259

246260
# ------------------------------
247261
# load sep from file if exists
262+
# ------------------------------
248263

249264
if _KEY_SEP in dflat.keys():
250265
# new
@@ -256,6 +271,7 @@ def load(
256271

257272
# ----------
258273
# reshape
274+
# ----------
259275

260276
dout = {}
261277
for k0, v0 in dflat.items():
@@ -310,14 +326,19 @@ def load(
310326

311327
# -----------
312328
# Instanciate
329+
# -----------
313330

314-
obj = cls.from_dict(dout)
331+
coll = cls.from_dict(dout, obj=coll)
332+
333+
# -----------
334+
# verb
335+
# -----------
315336

316337
if verb:
317338
msg = f"Loaded from\n\t{pfe}"
318339
print(msg)
319340

320-
return obj
341+
return coll
321342

322343

323344
# #################################################################

datastock/tests/test_01_DataStock.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,16 @@ def test24_saveload(self, verb=False):
612612
st2 = load(pfe, verb=verb)
613613
# Just to check the loaded version works fine
614614
msg = st2.__eq__(self.st, returnas=str)
615+
if msg is not True:
616+
raise Exception(msg)
617+
os.remove(pfe)
618+
619+
def test25_saveload_coll(self, verb=False):
620+
pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True)
621+
st = DataStock()
622+
st2 = load(pfe, coll=st, verb=verb)
623+
# Just to check the loaded version works fine
624+
msg = st2.__eq__(self.st, returnas=str)
615625
if msg is not True:
616626
raise Exception(msg)
617627
os.remove(pfe)

0 commit comments

Comments
 (0)