Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion datastock/_generic_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,95 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None):
)


# #############################################################################
# #############################################################################
# Utilities for plotting
# #############################################################################


def _check_all_broadcastable(**kwdargs):

# -------------------
# Preliminary check
# -------------------

dout = {}
dfail = {}
for k0, v0 in kwdargs.items():
try:
dout[k0] = np.atleast_1d(v0)
except Exception as err:
dfail[k0] = f"Not convertible to np.ndarray! - {v0}"

# Raise Exception
if len(dfail) > 0:
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
msg = (
"The following kwdargs are non-conform:\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# -------------------
# check ndim
# -------------------

dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)}
lndim = list(set(dndim.values()))

if len(lndim) == 0:
# all scalar
return {k0: v0[0] for k0, v0 in dout.items()}, None

elif len(lndim) == 1:
ndim = lndim[0]

else:
lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()]
msg = (
"Some keyword args have non-compatible dimensions:\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# -------------------
# check shapes
# -------------------

dfail = {}
shapef = np.ones((ndim,), dtype=int)
for k0, v0 in dout.items():

if v0.shape == (1,):
dout[k0] = v0[0]
continue

for ii in range(ndim):
if v0.shape[ii] == 1:
pass
elif shapef[ii] == 1:
shapef[ii] = v0.shape[ii]
elif v0.shape[ii] == shapef[ii]:
pass
else:
dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})"
continue

shapef = tuple(shapef)

# raise Exception if needed
if len(dfail) > 0:
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
msg = (
"The following keywords args have non-compatible shape:\n"
+ "\n".join(lstr)
+ f"\nReference shape: {shapef}\n"
)
raise Exception(msg)

return dout, shapef


# #############################################################################
# #############################################################################
# Utilities for plotting
Expand Down Expand Up @@ -929,4 +1018,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None):
else:
vmax = nanmax

return cmap, vmin, vmax
return cmap, vmin, vmax
35 changes: 34 additions & 1 deletion datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import matplotlib.pyplot as plt

# datastock-specific
from .._generic_check import _check_all_broadcastable
from .._class import DataStock
from .._saveload import load

Expand Down Expand Up @@ -228,6 +229,38 @@ def test02_add_data(self):
def test03_add_obj(self):
_add_obj(st=self.st, nc=self.nc)

# ------------------------
# Tools
# ------------------------

def test04_check_all_broadcastable(self):
# all scalar
dout, shape = _check_all_broadcastable(a=1, b=2)

# scalar + arrays
dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3))

# all arrays
dout, shape = _check_all_broadcastable(
a=(1, 2, 3),
b=(1, 2, 3),
)

# all arrays - 2d
dout, shape = _check_all_broadcastable(
a=np.r_[1, 2, 3][:, None],
b=np.r_[10, 20][None, :],
)

# check flag
err = False
try:
dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3))
except Exception:
err = True

assert err is True


#######################################################
#
Expand Down Expand Up @@ -639,4 +672,4 @@ def test26_saveload_coll(self, verb=False):
msg = st2.__eq__(self.st, returnas=str)
if msg is not True:
raise Exception(msg)
os.remove(pfe)
os.remove(pfe)
Loading