Skip to content
12 changes: 9 additions & 3 deletions datastock/_class2_interactivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,12 +459,18 @@ def _get_ix_for_refx_only_1or2d(


def get_fupdate(handle=None, dtype=None, norm=None, bstr=None):

# Note: set_xdata() and set_ydata() do not accept scalar values
# Deprecation warning since matplotlib 3.7
# see https://github.com/matplotlib/matplotlib/pull/22329
# see https://github.com/matplotlib/matplotlib/issues/28927

if dtype == 'xdata':
def func(val, handle=handle):
handle.set_xdata(val)
handle.set_xdata(np.atleast_1d(val))
elif dtype == 'ydata':
def func(val, handle=handle):
handle.set_ydata(val)
handle.set_ydata(np.atleast_1d(val))
elif dtype in ['data']: # Also works for imshow
def func(val, handle=handle):
handle.set_data(val)
Expand Down Expand Up @@ -588,4 +594,4 @@ def _update_mobile(k0=None, dmobile=None, dref=None, ddata=None):
# ddata[dmobile[k0]['data'][ii]]['data'][
# dmobile[k0]['func_slice'][ii](iref[ii])
# ]
# )
# )
122 changes: 121 additions & 1 deletion datastock/_generic_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,126 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None):
)


# #############################################################################
# #############################################################################
# Utilities for plotting
# #############################################################################


def _check_all_broadcastable(
return_full_arrays=None,
**kwdargs,
):

# -------------------
# return_full_arrays
# -------------------

return_full_arrays = _check_var(
return_full_arrays, 'return_full_arrays',
types=bool,
default=False,
)

# -------------------
# Preliminary check
# -------------------

dout = {}
dfail = {}
for k0, v0 in kwdargs.items():
try:
dout[k0] = np.atleast_1d(v0)
except Exception:
dfail[k0] = f"Not convertible to np.ndarray! - {v0}"

# Raise Exception
if len(dfail) > 0:
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
msg = (
"The following kwdargs are non-conform:\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# -------------------
# check ndim
# -------------------

dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)}
lndim = list(set(dndim.values()))

if len(lndim) == 0:
# all scalar
if return_full_arrays:
return dout, (1,)
else:
return {k0: v0[0] for k0, v0 in dout.items()}, None

elif len(lndim) == 1:
ndim = lndim[0]

else:
lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()]
msg = (
"Some keyword args have non-compatible dimensions:\n"
+ "\n".join(lstr)
)
raise Exception(msg)

# -------------------
# check shapes
# -------------------

dfail = {}
shapef = np.ones((ndim,), dtype=int)
for k0, v0 in dout.items():

if v0.shape == (1,):
continue

for ii in range(ndim):
if v0.shape[ii] == 1:
pass
elif shapef[ii] == 1:
shapef[ii] = v0.shape[ii]
elif v0.shape[ii] == shapef[ii]:
pass
else:
dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})"
continue

shapef = tuple(shapef)

# raise Exception if needed
if len(dfail) > 0:
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
msg = (
"The following keywords args have non-compatible shape:\n"
+ "\n".join(lstr)
+ f"\nReference shape: {shapef}\n"
)
raise Exception(msg)

# -------------------
# reshape output
# -------------------

if return_full_arrays is True:
for k0, v0 in dout.items():
if v0.shape == (1,):
dout[k0] = np.full(shapef, v0[0])
elif v0.shape != shapef:
dout[k0] = np.broadcast_to(v0, shapef)

else:
for k0, v0 in dout.items():
if v0.shape == (1,):
dout[k0] = v0[0]

return dout, shapef


# #############################################################################
# #############################################################################
# Utilities for plotting
Expand Down Expand Up @@ -929,4 +1049,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None):
else:
vmax = nanmax

return cmap, vmin, vmax
return cmap, vmin, vmax
43 changes: 42 additions & 1 deletion datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import matplotlib.pyplot as plt

# datastock-specific
from .._generic_check import _check_all_broadcastable
from .._class import DataStock
from .._saveload import load

Expand Down Expand Up @@ -228,6 +229,46 @@ def test02_add_data(self):
def test03_add_obj(self):
_add_obj(st=self.st, nc=self.nc)

# ------------------------
# Tools
# ------------------------

def test04_check_all_broadcastable(self):
# all scalar
dout, shape = _check_all_broadcastable(a=1, b=2)

# scalar + arrays
dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3))

# all arrays
dout, shape = _check_all_broadcastable(
a=(1, 2, 3),
b=(1, 2, 3),
)

# all arrays - 2d
dout, shape = _check_all_broadcastable(
a=np.r_[1, 2, 3][:, None],
b=np.r_[10, 20][None, :],
)

# check flag
err = False
try:
dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3))
except Exception:
err = True
assert err is True

# all arrays - mix
dout, shape = _check_all_broadcastable(
a=np.r_[1, 2, 3][:, None],
b=np.r_[10, 20][None, :],
c=3,
return_full_arrays=True,
)
assert all([v0.shape == (3, 2) for v0 in dout.values()])


#######################################################
#
Expand Down Expand Up @@ -639,4 +680,4 @@ def test26_saveload_coll(self, verb=False):
msg = st2.__eq__(self.st, returnas=str)
if msg is not True:
raise Exception(msg)
os.remove(pfe)
os.remove(pfe)
2 changes: 1 addition & 1 deletion datastock/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Do not edit, pipeline versioning governed by git tags!
__version__ = '0.0.46'
__version__ = '0.0.47'
Loading