diff --git a/datastock/_class2_interactivity.py b/datastock/_class2_interactivity.py index 9cc62e0..454726b 100644 --- a/datastock/_class2_interactivity.py +++ b/datastock/_class2_interactivity.py @@ -459,12 +459,18 @@ def _get_ix_for_refx_only_1or2d( def get_fupdate(handle=None, dtype=None, norm=None, bstr=None): + + # Note: set_xdata() and set_ydata() do not accept scalar values + # Deprecation warning since matplotlib 3.7 + # see https://github.com/matplotlib/matplotlib/pull/22329 + # see https://github.com/matplotlib/matplotlib/issues/28927 + if dtype == 'xdata': def func(val, handle=handle): - handle.set_xdata(val) + handle.set_xdata(np.atleast_1d(val)) elif dtype == 'ydata': def func(val, handle=handle): - handle.set_ydata(val) + handle.set_ydata(np.atleast_1d(val)) elif dtype in ['data']: # Also works for imshow def func(val, handle=handle): handle.set_data(val) @@ -588,4 +594,4 @@ def _update_mobile(k0=None, dmobile=None, dref=None, ddata=None): # ddata[dmobile[k0]['data'][ii]]['data'][ # dmobile[k0]['func_slice'][ii](iref[ii]) # ] - # ) \ No newline at end of file + # ) diff --git a/datastock/_generic_check.py b/datastock/_generic_check.py index c864655..7e4716c 100644 --- a/datastock/_generic_check.py +++ b/datastock/_generic_check.py @@ -550,6 +550,126 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None): ) +# ############################################################################# +# ############################################################################# +# Utilities for plotting +# ############################################################################# + + +def _check_all_broadcastable( + return_full_arrays=None, + **kwdargs, +): + + # ------------------- + # return_full_arrays + # ------------------- + + return_full_arrays = _check_var( + return_full_arrays, 'return_full_arrays', + types=bool, + default=False, + ) + + # ------------------- + # Preliminary check + # ------------------- + + dout = {} + dfail = {} + for k0, v0 in kwdargs.items(): + try: + dout[k0] = np.atleast_1d(v0) + except Exception: + dfail[k0] = f"Not convertible to np.ndarray! - {v0}" + + # Raise Exception + if len(dfail) > 0: + lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()] + msg = ( + "The following kwdargs are non-conform:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ------------------- + # check ndim + # ------------------- + + dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)} + lndim = list(set(dndim.values())) + + if len(lndim) == 0: + # all scalar + if return_full_arrays: + return dout, (1,) + else: + return {k0: v0[0] for k0, v0 in dout.items()}, None + + elif len(lndim) == 1: + ndim = lndim[0] + + else: + lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()] + msg = ( + "Some keyword args have non-compatible dimensions:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ------------------- + # check shapes + # ------------------- + + dfail = {} + shapef = np.ones((ndim,), dtype=int) + for k0, v0 in dout.items(): + + if v0.shape == (1,): + continue + + for ii in range(ndim): + if v0.shape[ii] == 1: + pass + elif shapef[ii] == 1: + shapef[ii] = v0.shape[ii] + elif v0.shape[ii] == shapef[ii]: + pass + else: + dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})" + continue + + shapef = tuple(shapef) + + # raise Exception if needed + if len(dfail) > 0: + lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()] + msg = ( + "The following keywords args have non-compatible shape:\n" + + "\n".join(lstr) + + f"\nReference shape: {shapef}\n" + ) + raise Exception(msg) + + # ------------------- + # reshape output + # ------------------- + + if return_full_arrays is True: + for k0, v0 in dout.items(): + if v0.shape == (1,): + dout[k0] = np.full(shapef, v0[0]) + elif v0.shape != shapef: + dout[k0] = np.broadcast_to(v0, shapef) + + else: + for k0, v0 in dout.items(): + if v0.shape == (1,): + dout[k0] = v0[0] + + return dout, shapef + + # ############################################################################# # ############################################################################# # Utilities for plotting @@ -929,4 +1049,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None): else: vmax = nanmax - return cmap, vmin, vmax \ No newline at end of file + return cmap, vmin, vmax diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 59da29a..93a969e 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -14,6 +14,7 @@ import matplotlib.pyplot as plt # datastock-specific +from .._generic_check import _check_all_broadcastable from .._class import DataStock from .._saveload import load @@ -228,6 +229,46 @@ def test02_add_data(self): def test03_add_obj(self): _add_obj(st=self.st, nc=self.nc) + # ------------------------ + # Tools + # ------------------------ + + def test04_check_all_broadcastable(self): + # all scalar + dout, shape = _check_all_broadcastable(a=1, b=2) + + # scalar + arrays + dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3)) + + # all arrays + dout, shape = _check_all_broadcastable( + a=(1, 2, 3), + b=(1, 2, 3), + ) + + # all arrays - 2d + dout, shape = _check_all_broadcastable( + a=np.r_[1, 2, 3][:, None], + b=np.r_[10, 20][None, :], + ) + + # check flag + err = False + try: + dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3)) + except Exception: + err = True + assert err is True + + # all arrays - mix + dout, shape = _check_all_broadcastable( + a=np.r_[1, 2, 3][:, None], + b=np.r_[10, 20][None, :], + c=3, + return_full_arrays=True, + ) + assert all([v0.shape == (3, 2) for v0 in dout.values()]) + ####################################################### # @@ -639,4 +680,4 @@ def test26_saveload_coll(self, verb=False): msg = st2.__eq__(self.st, returnas=str) if msg is not True: raise Exception(msg) - os.remove(pfe) \ No newline at end of file + os.remove(pfe) diff --git a/datastock/version.py b/datastock/version.py index 934bed1..c59cfb2 100644 --- a/datastock/version.py +++ b/datastock/version.py @@ -1,2 +1,2 @@ # Do not edit, pipeline versioning governed by git tags! -__version__ = '0.0.46' +__version__ = '0.0.47'