Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions datastock/_class1_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ def domain_ref(
# -----------
# get indices

lvectu = sorted({v0['vect'] for v0 in domain.values()})
lvectu = sorted({
v0['vect'] for v0 in domain.values()
if v0.get('vect') is not None
})

for vv in lvectu:

lk0 = [k0 for k0, v0 in domain.items() if v0['vect'] == vv]
lk0 = [k0 for k0, v0 in domain.items() if v0.get('vect') == vv]
for k0 in lk0:

if domain[k0].get('domain') is None:
Expand All @@ -61,12 +64,14 @@ def _check(

# ---------
# prepare
# ---------

ldata = list(coll.ddata.keys())
lref = list(coll.dref.keys())

# ------------
# domain
# ------------

c0 = (
isinstance(domain, dict)
Expand All @@ -80,24 +85,17 @@ def _check(
)
raise Exception(msg)

# ------------
# --------------
# check each key
# --------------

dfail = {}
domain = copy.deepcopy(domain)
for k0, v0 in domain.items():

# check ref vector
kwd = {'ref': k0} if k0 in lref else {'key0': k0}
hasref, hasvect, ref, vect = coll.get_ref_vector(**kwd)[:4]
if not (hasref and ref is not None):
dfail[k0] = "No associated ref identified!"
continue
if not (hasvect and vect is not None):
dfail[k0] = "No associated ref vector identified!"
continue

# -----------
# v0 is dict

ltyp = (list, tuple, np.ndarray)
if isinstance(v0, ltyp):
domain[k0] = {'domain': v0}
Expand All @@ -106,21 +104,42 @@ def _check(

c0 = (
isinstance(domain[k0], dict)
and any(ss in ['ind', 'domain'] for ss in domain[k0].keys())
and any([ss in ['ind', 'domain'] for ss in domain[k0].keys()])
and (
isinstance(domain[k0].get('domain'), ltyp)
or np.isscalar(domain[k0].get('domain', 0))
)
and isinstance(domain[k0].get('ind', np.r_[0]), np.ndarray)
and isinstance(domain[k0].get('ind', np.r_[0]), (np.ndarray, int))
)

if not c0:
dfail[k0] = "must be a dict with keys ['ind', 'domain']"
continue

# ----------------
# check ref vector

kwd = {'ref': k0} if k0 in lref else {'key0': k0}
hasref, hasvect, ref, vect = coll.get_ref_vector(**kwd)[:4]

if not (hasref and ref is not None):
dfail[k0] = "No associated ref identified!"
continue

# vect
domain[k0]['vect'] = vect
domain[k0]['ref'] = ref

if domain[k0].get('domain') is not None:
if not (hasvect and vect is not None):
dfail[k0] = "No associated ref vector identified!"
continue

# vect
domain[k0]['vect'] = vect

# -------
# domain

dom = domain[k0].get('domain')
if dom is not None:
dom, err = _check_domain(dom)
Expand All @@ -129,10 +148,15 @@ def _check(
continue
domain[k0]['domain'] = dom

# -----
# ind

ind = domain[k0].get('ind')
if ind is not None:
vsize = coll.ddata[vect]['data'].size
if np.isscalar(ind):
ind = np.array([ind], dtype=int)

vsize = coll.dref[ref]['size']
if ind.dtype == bool:
pass
elif 'int' in ind.dtype.name:
Expand All @@ -151,12 +175,14 @@ def _check(

# -----------
# errors
# -----------

if len(dfail) > 0:
lstr = [f"\t- '{k0}': {v0}" for k0, v0 in dfail.items()]
msg = (
"The following domain keys / values are not conform:\n"
+ "\n".join(lstr)
+ f"\nProvided:\n{domain}"
)
raise Exception(msg)

Expand Down
84 changes: 47 additions & 37 deletions datastock/_class1_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def interpolate(
ddata, dout, dsh_other, sli_c, sli_x, sli_v,
log_log, nan0, grid, ndim, xunique,
returnas, return_params, store, inplace,
domain,
) = _check(
coll=coll,
# interpolation base
Expand Down Expand Up @@ -143,7 +144,7 @@ def interpolate(
# adjust data and ref if xunique

if xunique:
_xunique(dout)
_xunique(dout, domain=domain)

# --------
# store
Expand Down Expand Up @@ -392,9 +393,9 @@ def _check(
)

# ---------------------
# get dvect from domain
# get dref_dom from domain

domain, dvect = _get_dvect(
domain, dref_dom = _get_drefdom(
coll=coll,
domain=domain,
ref_key=ref_key,
Expand All @@ -407,7 +408,7 @@ def _check(
coll=coll,
keys=keys,
ref_key=ref_key,
dvect=dvect,
dref_dom=dref_dom,
)

# --------
Expand All @@ -422,7 +423,7 @@ def _check(
)

if ref_com is not None and domain is not None:
if ref_com in [coll.ddata[k0]['ref'][0] for k0 in dvect.keys()]:
if ref_com in list(dref_dom.keys()):
msg = (
"Arg ref_com and domain cannot be applied to the same ref!\n"
f"\t- ref_com: {ref_com}\n"
Expand All @@ -440,8 +441,10 @@ def _check(
x0=x0,
daxis=daxis,
dunits=dunits,
# ref com
dref_com=dref_com,
dvect=dvect,
# domain
dref_dom=dref_dom,
)

# --------------
Expand Down Expand Up @@ -488,6 +491,7 @@ def _check(
ddata, dout, dsh_other, sli_c, sli_x, sli_v,
log_log, nan0, grid, ndim, xunique,
returnas, return_params, store, inplace,
domain,
)


Expand Down Expand Up @@ -963,56 +967,59 @@ def _x01_grid(
return x0, x1, refx, ix, xunique


def _get_dvect(
def _get_drefdom(
coll=None,
domain=None,
ref_key=None,
):
# ----------------
# domain => dvect
# domain => dref_dom

lr_ref_key = [coll.ddata[kk]['ref'][0] for kk in ref_key]

if domain is not None:

# get domain
domain = coll.get_domain_ref(domain)

# derive dvect
lvectu = sorted({
v0['vect'] for v0 in domain.values() if v0['vect'] not in ref_key
# derive lrefu
lrefu = sorted({
v0['ref'] for v0 in domain.values() if v0['ref'] not in lr_ref_key
})

dvect = {
k0: [k1 for k1, v1 in domain.items() if v1['vect'] == k0]
for k0 in lvectu
# derive dref_dom
dref_dom = {
rr: [k1 for k1, v1 in domain.items() if v1['ref'] == rr]
for rr in lrefu
}

# check unicity of vect
dfail = {k0: v0 for k0, v0 in dvect.items() if len(v0) > 1}
dfail = {k0: v0 for k0, v0 in dref_dom.items() if len(v0) > 1}
if len(dfail) > 0:
lstr = [f"\t- '{k0}': {v0}" for k0, v0 in dfail.items()]
msg = (
"Some ref vector have been specified with multiple domains!\n"
"Some ref have been specified with multiple domains!\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# build final dvect
dvect = {
# build final dref_dom
dref_dom = {
k0: domain[v0[0]]['ind']
for k0, v0 in dvect.items()
for k0, v0 in dref_dom.items()
}

else:
dvect = None
dref_dom = None

return domain, dvect
return domain, dref_dom


def _get_ddata(
coll=None,
keys=None,
ref_key=None,
dvect=None,
dref_dom=None,
):

# --------
Expand All @@ -1024,13 +1031,12 @@ def _get_ddata(
data = coll.ddata[k0]['data']

# apply domain
if dvect is not None:
for k1, v1 in dvect.items():
ref = coll.ddata[k1]['ref'][0]
if ref in coll.ddata[k0]['ref']:
ax = coll.ddata[k0]['ref'].index(ref)
if dref_dom is not None:
for rr, vr in dref_dom.items():
if rr in coll.ddata[k0]['ref']:
ax = coll.ddata[k0]['ref'].index(rr)
sli = tuple([
v1 if ii == ax else slice(None)
vr if ii == ax else slice(None)
for ii in range(data.ndim)
])
data = data[sli]
Expand All @@ -1050,7 +1056,7 @@ def _get_dout(
# common refs
dref_com=None,
# domain
dvect=None,
dref_dom=None,
):

# -------------
Expand All @@ -1069,11 +1075,11 @@ def _get_dout(
rd = list(coll.ddata[k0]['ref'])

# apply domain
if dvect is not None:
for k1, v1 in dvect.items():
if coll.ddata[k1]['ref'][0] in rd:
ax = rd.index(coll.ddata[k1]['ref'][0])
sh[ax] = len(v1) if v1.dtype == int else v1.sum()
if dref_dom is not None:
for rr, vr in dref_dom.items():
if rr in rd:
ax = rd.index(rr)
sh[ax] = len(vr) if vr.dtype == int else vr.sum()
rd[ax] = None

# ------------------------
Expand Down Expand Up @@ -1556,7 +1562,7 @@ def _interp2d(
# ###############################################################


def _xunique(dout=None):
def _xunique(dout=None, domain=None):
""" interpolation on a single point => eliminates a ref """

# ----------
Expand All @@ -1567,13 +1573,17 @@ def _xunique(dout=None):
for k0, v0 in dout.items()
}

dwrong = {k0: v0 for k0, v0 in dind.items() if len(v0) != 1}
# Number of Nones expected
nNone = 1 + len(domain)

# check
dwrong = {k0: v0 for k0, v0 in dind.items() if len(v0) != nNone}
if len(dwrong) > 0:
lstr = [
f"\t- {k0}: {dout[k0]['ref']} => {v0}" for k0, v0 in dwrong.items()
]
msg = (
"Interpolation at unique point => ref should have one None:\n"
"Interpolate unique pt => ref should have nNone = 1 + {len(domain)}:\n"
+ "\n".join(lstr)
)
raise Exception(msg)
Expand Down
2 changes: 1 addition & 1 deletion datastock/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Do not edit, pipeline versioning governed by git tags!
__version__ = '0.0.44'
__version__ = '0.0.45'
Loading