Skip to content

Commit b49bd24

Browse files
authored
Merge pull request #206 from ToFuProject/devel
Prepare 0.0.47
2 parents 5cf2e4a + ef64a1a commit b49bd24

File tree

4 files changed

+173
-6
lines changed

4 files changed

+173
-6
lines changed

datastock/_class2_interactivity.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,18 @@ def _get_ix_for_refx_only_1or2d(
459459

460460

461461
def get_fupdate(handle=None, dtype=None, norm=None, bstr=None):
462+
463+
# Note: set_xdata() and set_ydata() do not accept scalar values
464+
# Deprecation warning since matplotlib 3.7
465+
# see https://github.com/matplotlib/matplotlib/pull/22329
466+
# see https://github.com/matplotlib/matplotlib/issues/28927
467+
462468
if dtype == 'xdata':
463469
def func(val, handle=handle):
464-
handle.set_xdata(val)
470+
handle.set_xdata(np.atleast_1d(val))
465471
elif dtype == 'ydata':
466472
def func(val, handle=handle):
467-
handle.set_ydata(val)
473+
handle.set_ydata(np.atleast_1d(val))
468474
elif dtype in ['data']: # Also works for imshow
469475
def func(val, handle=handle):
470476
handle.set_data(val)
@@ -588,4 +594,4 @@ def _update_mobile(k0=None, dmobile=None, dref=None, ddata=None):
588594
# ddata[dmobile[k0]['data'][ii]]['data'][
589595
# dmobile[k0]['func_slice'][ii](iref[ii])
590596
# ]
591-
# )
597+
# )

datastock/_generic_check.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,126 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None):
550550
)
551551

552552

553+
# #############################################################################
554+
# #############################################################################
555+
# Utilities for plotting
556+
# #############################################################################
557+
558+
559+
def _check_all_broadcastable(
560+
return_full_arrays=None,
561+
**kwdargs,
562+
):
563+
564+
# -------------------
565+
# return_full_arrays
566+
# -------------------
567+
568+
return_full_arrays = _check_var(
569+
return_full_arrays, 'return_full_arrays',
570+
types=bool,
571+
default=False,
572+
)
573+
574+
# -------------------
575+
# Preliminary check
576+
# -------------------
577+
578+
dout = {}
579+
dfail = {}
580+
for k0, v0 in kwdargs.items():
581+
try:
582+
dout[k0] = np.atleast_1d(v0)
583+
except Exception:
584+
dfail[k0] = f"Not convertible to np.ndarray! - {v0}"
585+
586+
# Raise Exception
587+
if len(dfail) > 0:
588+
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
589+
msg = (
590+
"The following kwdargs are non-conform:\n"
591+
+ "\n".join(lstr)
592+
)
593+
raise Exception(msg)
594+
595+
# -------------------
596+
# check ndim
597+
# -------------------
598+
599+
dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)}
600+
lndim = list(set(dndim.values()))
601+
602+
if len(lndim) == 0:
603+
# all scalar
604+
if return_full_arrays:
605+
return dout, (1,)
606+
else:
607+
return {k0: v0[0] for k0, v0 in dout.items()}, None
608+
609+
elif len(lndim) == 1:
610+
ndim = lndim[0]
611+
612+
else:
613+
lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()]
614+
msg = (
615+
"Some keyword args have non-compatible dimensions:\n"
616+
+ "\n".join(lstr)
617+
)
618+
raise Exception(msg)
619+
620+
# -------------------
621+
# check shapes
622+
# -------------------
623+
624+
dfail = {}
625+
shapef = np.ones((ndim,), dtype=int)
626+
for k0, v0 in dout.items():
627+
628+
if v0.shape == (1,):
629+
continue
630+
631+
for ii in range(ndim):
632+
if v0.shape[ii] == 1:
633+
pass
634+
elif shapef[ii] == 1:
635+
shapef[ii] = v0.shape[ii]
636+
elif v0.shape[ii] == shapef[ii]:
637+
pass
638+
else:
639+
dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})"
640+
continue
641+
642+
shapef = tuple(shapef)
643+
644+
# raise Exception if needed
645+
if len(dfail) > 0:
646+
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
647+
msg = (
648+
"The following keywords args have non-compatible shape:\n"
649+
+ "\n".join(lstr)
650+
+ f"\nReference shape: {shapef}\n"
651+
)
652+
raise Exception(msg)
653+
654+
# -------------------
655+
# reshape output
656+
# -------------------
657+
658+
if return_full_arrays is True:
659+
for k0, v0 in dout.items():
660+
if v0.shape == (1,):
661+
dout[k0] = np.full(shapef, v0[0])
662+
elif v0.shape != shapef:
663+
dout[k0] = np.broadcast_to(v0, shapef)
664+
665+
else:
666+
for k0, v0 in dout.items():
667+
if v0.shape == (1,):
668+
dout[k0] = v0[0]
669+
670+
return dout, shapef
671+
672+
553673
# #############################################################################
554674
# #############################################################################
555675
# Utilities for plotting
@@ -929,4 +1049,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None):
9291049
else:
9301050
vmax = nanmax
9311051

932-
return cmap, vmin, vmax
1052+
return cmap, vmin, vmax

datastock/tests/test_01_DataStock.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import matplotlib.pyplot as plt
1515

1616
# datastock-specific
17+
from .._generic_check import _check_all_broadcastable
1718
from .._class import DataStock
1819
from .._saveload import load
1920

@@ -228,6 +229,46 @@ def test02_add_data(self):
228229
def test03_add_obj(self):
229230
_add_obj(st=self.st, nc=self.nc)
230231

232+
# ------------------------
233+
# Tools
234+
# ------------------------
235+
236+
def test04_check_all_broadcastable(self):
237+
# all scalar
238+
dout, shape = _check_all_broadcastable(a=1, b=2)
239+
240+
# scalar + arrays
241+
dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3))
242+
243+
# all arrays
244+
dout, shape = _check_all_broadcastable(
245+
a=(1, 2, 3),
246+
b=(1, 2, 3),
247+
)
248+
249+
# all arrays - 2d
250+
dout, shape = _check_all_broadcastable(
251+
a=np.r_[1, 2, 3][:, None],
252+
b=np.r_[10, 20][None, :],
253+
)
254+
255+
# check flag
256+
err = False
257+
try:
258+
dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3))
259+
except Exception:
260+
err = True
261+
assert err is True
262+
263+
# all arrays - mix
264+
dout, shape = _check_all_broadcastable(
265+
a=np.r_[1, 2, 3][:, None],
266+
b=np.r_[10, 20][None, :],
267+
c=3,
268+
return_full_arrays=True,
269+
)
270+
assert all([v0.shape == (3, 2) for v0 in dout.values()])
271+
231272

232273
#######################################################
233274
#
@@ -639,4 +680,4 @@ def test26_saveload_coll(self, verb=False):
639680
msg = st2.__eq__(self.st, returnas=str)
640681
if msg is not True:
641682
raise Exception(msg)
642-
os.remove(pfe)
683+
os.remove(pfe)

datastock/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Do not edit, pipeline versioning governed by git tags!
2-
__version__ = '0.0.46'
2+
__version__ = '0.0.47'

0 commit comments

Comments
 (0)