diff --git a/datastock/_class1_domain.py b/datastock/_class1_domain.py index d483833..f2e574e 100644 --- a/datastock/_class1_domain.py +++ b/datastock/_class1_domain.py @@ -30,11 +30,14 @@ def domain_ref( # ----------- # get indices - lvectu = sorted({v0['vect'] for v0 in domain.values()}) + lvectu = sorted({ + v0['vect'] for v0 in domain.values() + if v0.get('vect') is not None + }) for vv in lvectu: - lk0 = [k0 for k0, v0 in domain.items() if v0['vect'] == vv] + lk0 = [k0 for k0, v0 in domain.items() if v0.get('vect') == vv] for k0 in lk0: if domain[k0].get('domain') is None: @@ -61,12 +64,14 @@ def _check( # --------- # prepare + # --------- ldata = list(coll.ddata.keys()) lref = list(coll.dref.keys()) # ------------ # domain + # ------------ c0 = ( isinstance(domain, dict) @@ -80,24 +85,17 @@ def _check( ) raise Exception(msg) - # ------------ + # -------------- # check each key + # -------------- dfail = {} domain = copy.deepcopy(domain) for k0, v0 in domain.items(): - # check ref vector - kwd = {'ref': k0} if k0 in lref else {'key0': k0} - hasref, hasvect, ref, vect = coll.get_ref_vector(**kwd)[:4] - if not (hasref and ref is not None): - dfail[k0] = "No associated ref identified!" - continue - if not (hasvect and vect is not None): - dfail[k0] = "No associated ref vector identified!" - continue - + # ----------- # v0 is dict + ltyp = (list, tuple, np.ndarray) if isinstance(v0, ltyp): domain[k0] = {'domain': v0} @@ -106,21 +104,42 @@ def _check( c0 = ( isinstance(domain[k0], dict) - and any(ss in ['ind', 'domain'] for ss in domain[k0].keys()) + and any([ss in ['ind', 'domain'] for ss in domain[k0].keys()]) and ( isinstance(domain[k0].get('domain'), ltyp) or np.isscalar(domain[k0].get('domain', 0)) ) - and isinstance(domain[k0].get('ind', np.r_[0]), np.ndarray) + and isinstance(domain[k0].get('ind', np.r_[0]), (np.ndarray, int)) ) + if not c0: dfail[k0] = "must be a dict with keys ['ind', 'domain']" continue + # ---------------- + # check ref vector + + kwd = {'ref': k0} if k0 in lref else {'key0': k0} + hasref, hasvect, ref, vect = coll.get_ref_vector(**kwd)[:4] + + if not (hasref and ref is not None): + dfail[k0] = "No associated ref identified!" + continue + # vect - domain[k0]['vect'] = vect + domain[k0]['ref'] = ref + if domain[k0].get('domain') is not None: + if not (hasvect and vect is not None): + dfail[k0] = "No associated ref vector identified!" + continue + + # vect + domain[k0]['vect'] = vect + + # ------- # domain + dom = domain[k0].get('domain') if dom is not None: dom, err = _check_domain(dom) @@ -129,10 +148,15 @@ def _check( continue domain[k0]['domain'] = dom + # ----- # ind + ind = domain[k0].get('ind') if ind is not None: - vsize = coll.ddata[vect]['data'].size + if np.isscalar(ind): + ind = np.array([ind], dtype=int) + + vsize = coll.dref[ref]['size'] if ind.dtype == bool: pass elif 'int' in ind.dtype.name: @@ -151,12 +175,14 @@ def _check( # ----------- # errors + # ----------- if len(dfail) > 0: lstr = [f"\t- '{k0}': {v0}" for k0, v0 in dfail.items()] msg = ( "The following domain keys / values are not conform:\n" + "\n".join(lstr) + + f"\nProvided:\n{domain}" ) raise Exception(msg) diff --git a/datastock/_class1_interpolate.py b/datastock/_class1_interpolate.py index 97a92d2..d9236dc 100644 --- a/datastock/_class1_interpolate.py +++ b/datastock/_class1_interpolate.py @@ -88,6 +88,7 @@ def interpolate( ddata, dout, dsh_other, sli_c, sli_x, sli_v, log_log, nan0, grid, ndim, xunique, returnas, return_params, store, inplace, + domain, ) = _check( coll=coll, # interpolation base @@ -143,7 +144,7 @@ def interpolate( # adjust data and ref if xunique if xunique: - _xunique(dout) + _xunique(dout, domain=domain) # -------- # store @@ -392,9 +393,9 @@ def _check( ) # --------------------- - # get dvect from domain + # get dref_dom from domain - domain, dvect = _get_dvect( + domain, dref_dom = _get_drefdom( coll=coll, domain=domain, ref_key=ref_key, @@ -407,7 +408,7 @@ def _check( coll=coll, keys=keys, ref_key=ref_key, - dvect=dvect, + dref_dom=dref_dom, ) # -------- @@ -422,7 +423,7 @@ def _check( ) if ref_com is not None and domain is not None: - if ref_com in [coll.ddata[k0]['ref'][0] for k0 in dvect.keys()]: + if ref_com in list(dref_dom.keys()): msg = ( "Arg ref_com and domain cannot be applied to the same ref!\n" f"\t- ref_com: {ref_com}\n" @@ -440,8 +441,10 @@ def _check( x0=x0, daxis=daxis, dunits=dunits, + # ref com dref_com=dref_com, - dvect=dvect, + # domain + dref_dom=dref_dom, ) # -------------- @@ -488,6 +491,7 @@ def _check( ddata, dout, dsh_other, sli_c, sli_x, sli_v, log_log, nan0, grid, ndim, xunique, returnas, return_params, store, inplace, + domain, ) @@ -963,56 +967,59 @@ def _x01_grid( return x0, x1, refx, ix, xunique -def _get_dvect( +def _get_drefdom( coll=None, domain=None, ref_key=None, ): # ---------------- - # domain => dvect + # domain => dref_dom + + lr_ref_key = [coll.ddata[kk]['ref'][0] for kk in ref_key] if domain is not None: # get domain domain = coll.get_domain_ref(domain) - # derive dvect - lvectu = sorted({ - v0['vect'] for v0 in domain.values() if v0['vect'] not in ref_key + # derive lrefu + lrefu = sorted({ + v0['ref'] for v0 in domain.values() if v0['ref'] not in lr_ref_key }) - dvect = { - k0: [k1 for k1, v1 in domain.items() if v1['vect'] == k0] - for k0 in lvectu + # derive dref_dom + dref_dom = { + rr: [k1 for k1, v1 in domain.items() if v1['ref'] == rr] + for rr in lrefu } # check unicity of vect - dfail = {k0: v0 for k0, v0 in dvect.items() if len(v0) > 1} + dfail = {k0: v0 for k0, v0 in dref_dom.items() if len(v0) > 1} if len(dfail) > 0: lstr = [f"\t- '{k0}': {v0}" for k0, v0 in dfail.items()] msg = ( - "Some ref vector have been specified with multiple domains!\n" + "Some ref have been specified with multiple domains!\n" + "\n".join(lstr) ) raise Exception(msg) - # build final dvect - dvect = { + # build final dref_dom + dref_dom = { k0: domain[v0[0]]['ind'] - for k0, v0 in dvect.items() + for k0, v0 in dref_dom.items() } else: - dvect = None + dref_dom = None - return domain, dvect + return domain, dref_dom def _get_ddata( coll=None, keys=None, ref_key=None, - dvect=None, + dref_dom=None, ): # -------- @@ -1024,13 +1031,12 @@ def _get_ddata( data = coll.ddata[k0]['data'] # apply domain - if dvect is not None: - for k1, v1 in dvect.items(): - ref = coll.ddata[k1]['ref'][0] - if ref in coll.ddata[k0]['ref']: - ax = coll.ddata[k0]['ref'].index(ref) + if dref_dom is not None: + for rr, vr in dref_dom.items(): + if rr in coll.ddata[k0]['ref']: + ax = coll.ddata[k0]['ref'].index(rr) sli = tuple([ - v1 if ii == ax else slice(None) + vr if ii == ax else slice(None) for ii in range(data.ndim) ]) data = data[sli] @@ -1050,7 +1056,7 @@ def _get_dout( # common refs dref_com=None, # domain - dvect=None, + dref_dom=None, ): # ------------- @@ -1069,11 +1075,11 @@ def _get_dout( rd = list(coll.ddata[k0]['ref']) # apply domain - if dvect is not None: - for k1, v1 in dvect.items(): - if coll.ddata[k1]['ref'][0] in rd: - ax = rd.index(coll.ddata[k1]['ref'][0]) - sh[ax] = len(v1) if v1.dtype == int else v1.sum() + if dref_dom is not None: + for rr, vr in dref_dom.items(): + if rr in rd: + ax = rd.index(rr) + sh[ax] = len(vr) if vr.dtype == int else vr.sum() rd[ax] = None # ------------------------ @@ -1556,7 +1562,7 @@ def _interp2d( # ############################################################### -def _xunique(dout=None): +def _xunique(dout=None, domain=None): """ interpolation on a single point => eliminates a ref """ # ---------- @@ -1567,13 +1573,17 @@ def _xunique(dout=None): for k0, v0 in dout.items() } - dwrong = {k0: v0 for k0, v0 in dind.items() if len(v0) != 1} + # Number of Nones expected + nNone = 1 + len(domain) + + # check + dwrong = {k0: v0 for k0, v0 in dind.items() if len(v0) != nNone} if len(dwrong) > 0: lstr = [ f"\t- {k0}: {dout[k0]['ref']} => {v0}" for k0, v0 in dwrong.items() ] msg = ( - "Interpolation at unique point => ref should have one None:\n" + "Interpolate unique pt => ref should have nNone = 1 + {len(domain)}:\n" + "\n".join(lstr) ) raise Exception(msg) diff --git a/datastock/version.py b/datastock/version.py index b880254..9f29e9b 100644 --- a/datastock/version.py +++ b/datastock/version.py @@ -1,2 +1,2 @@ # Do not edit, pipeline versioning governed by git tags! -__version__ = '0.0.44' +__version__ = '0.0.45'