diff --git a/datastock/_class0.py b/datastock/_class0.py index 26b909d..0c54855 100644 --- a/datastock/_class0.py +++ b/datastock/_class0.py @@ -65,7 +65,7 @@ def to_dict( ) @classmethod - def from_dict(cls, din=None, isflat=None, sep=None): + def from_dict(cls, din=None, isflat=None, sep=None, obj=None): """ Populate the instances attributes using an input dict The input dict must be properly formatted @@ -85,7 +85,9 @@ def from_dict(cls, din=None, isflat=None, sep=None): # --------------------- # Instanciate and populate - obj = cls() + if obj is None: + obj = cls() + for k0 in din.keys(): if k0 == '_ddef': if 'dobj' not in din[k0]['params'].keys(): diff --git a/datastock/_class1.py b/datastock/_class1.py index 27e2304..d4003c9 100644 --- a/datastock/_class1.py +++ b/datastock/_class1.py @@ -369,8 +369,8 @@ def remove_param( ########### @classmethod - def from_dict(cls, din=None, sep=None): - obj = super().from_dict(din=din, sep=sep) + def from_dict(cls, din=None, sep=None, obj=None): + obj = super().from_dict(din=din, sep=sep, obj=obj) obj.update() return obj diff --git a/datastock/_saveload.py b/datastock/_saveload.py index 1524c42..62053d8 100644 --- a/datastock/_saveload.py +++ b/datastock/_saveload.py @@ -199,6 +199,7 @@ def save( def load( pfe=None, cls=None, + coll=None, allow_pickle=None, sep=None, verb=None, @@ -206,16 +207,24 @@ def load( # ------------- # check inputs + # ------------- + # --------- # pfe + if not os.path.isfile(pfe): msg = f"Arg pfe must be a valid path to a file!\n\t- Provided: {pfe}" raise Exception(msg) - # cls - if cls is None: - from ._class import DataStock - cls = DataStock + # -------------- + # cls vs coll + + if coll is None: + if cls is None: + from ._class import DataStock + cls = DataStock + else: + cls = coll.__class__ if not (type(cls) is type and hasattr(cls, 'from_dict')): msg = ( @@ -224,14 +233,18 @@ def load( ) raise Exception(msg) + # ------------ # allow_pickle + allow_pickle = _generic_check._check_var( allow_pickle, 'allow_pickle', default=True, types=bool, ) + # ------- # verb + verb = _generic_check._check_var( verb, 'verb', default=True, @@ -240,11 +253,13 @@ def load( # -------------- # load flat dict + # -------------- dflat = dict(np.load(pfe, allow_pickle=allow_pickle)) # ------------------------------ # load sep from file if exists + # ------------------------------ if _KEY_SEP in dflat.keys(): # new @@ -256,6 +271,7 @@ def load( # ---------- # reshape + # ---------- dout = {} for k0, v0 in dflat.items(): @@ -310,14 +326,19 @@ def load( # ----------- # Instanciate + # ----------- - obj = cls.from_dict(dout) + coll = cls.from_dict(dout, obj=coll) + + # ----------- + # verb + # ----------- if verb: msg = f"Loaded from\n\t{pfe}" print(msg) - return obj + return coll # ################################################################# diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 2b2213d..374bab1 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -612,6 +612,16 @@ def test24_saveload(self, verb=False): st2 = load(pfe, verb=verb) # Just to check the loaded version works fine msg = st2.__eq__(self.st, returnas=str) + if msg is not True: + raise Exception(msg) + os.remove(pfe) + + def test25_saveload_coll(self, verb=False): + pfe = self.st.save(path=_PATH_OUTPUT, verb=verb, return_pfe=True) + st = DataStock() + st2 = load(pfe, coll=st, verb=verb) + # Just to check the loaded version works fine + msg = st2.__eq__(self.st, returnas=str) if msg is not True: raise Exception(msg) os.remove(pfe) \ No newline at end of file diff --git a/datastock/version.py b/datastock/version.py index ef44963..b880254 100644 --- a/datastock/version.py +++ b/datastock/version.py @@ -1,2 +1,2 @@ # Do not edit, pipeline versioning governed by git tags! -__version__ = '0.0.43' +__version__ = '0.0.44'