Skip to content

Commit 3f01820

Browse files
authored
Merge pull request #203 from ToFuProject/Issue202_UniformizeParamsShapes
Issue202 `_check_all_broadcastable(**kwdargs)`
2 parents 71a20db + 152ed4b commit 3f01820

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

datastock/_generic_check.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,95 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None):
550550
)
551551

552552

553+
# #############################################################################
554+
# #############################################################################
555+
# Utilities for plotting
556+
# #############################################################################
557+
558+
559+
def _check_all_broadcastable(**kwdargs):
560+
561+
# -------------------
562+
# Preliminary check
563+
# -------------------
564+
565+
dout = {}
566+
dfail = {}
567+
for k0, v0 in kwdargs.items():
568+
try:
569+
dout[k0] = np.atleast_1d(v0)
570+
except Exception as err:
571+
dfail[k0] = f"Not convertible to np.ndarray! - {v0}"
572+
573+
# Raise Exception
574+
if len(dfail) > 0:
575+
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
576+
msg = (
577+
"The following kwdargs are non-conform:\n"
578+
+ "\n".join(lstr)
579+
)
580+
raise Exception(msg)
581+
582+
# -------------------
583+
# check ndim
584+
# -------------------
585+
586+
dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)}
587+
lndim = list(set(dndim.values()))
588+
589+
if len(lndim) == 0:
590+
# all scalar
591+
return {k0: v0[0] for k0, v0 in dout.items()}, None
592+
593+
elif len(lndim) == 1:
594+
ndim = lndim[0]
595+
596+
else:
597+
lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()]
598+
msg = (
599+
"Some keyword args have non-compatible dimensions:\n"
600+
+ "\n".join(lstr)
601+
)
602+
raise Exception(msg)
603+
604+
# -------------------
605+
# check shapes
606+
# -------------------
607+
608+
dfail = {}
609+
shapef = np.ones((ndim,), dtype=int)
610+
for k0, v0 in dout.items():
611+
612+
if v0.shape == (1,):
613+
dout[k0] = v0[0]
614+
continue
615+
616+
for ii in range(ndim):
617+
if v0.shape[ii] == 1:
618+
pass
619+
elif shapef[ii] == 1:
620+
shapef[ii] = v0.shape[ii]
621+
elif v0.shape[ii] == shapef[ii]:
622+
pass
623+
else:
624+
dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})"
625+
continue
626+
627+
shapef = tuple(shapef)
628+
629+
# raise Exception if needed
630+
if len(dfail) > 0:
631+
lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()]
632+
msg = (
633+
"The following keywords args have non-compatible shape:\n"
634+
+ "\n".join(lstr)
635+
+ f"\nReference shape: {shapef}\n"
636+
)
637+
raise Exception(msg)
638+
639+
return dout, shapef
640+
641+
553642
# #############################################################################
554643
# #############################################################################
555644
# Utilities for plotting
@@ -929,4 +1018,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None):
9291018
else:
9301019
vmax = nanmax
9311020

932-
return cmap, vmin, vmax
1021+
return cmap, vmin, vmax

datastock/tests/test_01_DataStock.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import matplotlib.pyplot as plt
1515

1616
# datastock-specific
17+
from .._generic_check import _check_all_broadcastable
1718
from .._class import DataStock
1819
from .._saveload import load
1920

@@ -228,6 +229,38 @@ def test02_add_data(self):
228229
def test03_add_obj(self):
229230
_add_obj(st=self.st, nc=self.nc)
230231

232+
# ------------------------
233+
# Tools
234+
# ------------------------
235+
236+
def test04_check_all_broadcastable(self):
237+
# all scalar
238+
dout, shape = _check_all_broadcastable(a=1, b=2)
239+
240+
# scalar + arrays
241+
dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3))
242+
243+
# all arrays
244+
dout, shape = _check_all_broadcastable(
245+
a=(1, 2, 3),
246+
b=(1, 2, 3),
247+
)
248+
249+
# all arrays - 2d
250+
dout, shape = _check_all_broadcastable(
251+
a=np.r_[1, 2, 3][:, None],
252+
b=np.r_[10, 20][None, :],
253+
)
254+
255+
# check flag
256+
err = False
257+
try:
258+
dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3))
259+
except Exception:
260+
err = True
261+
262+
assert err is True
263+
231264

232265
#######################################################
233266
#
@@ -639,4 +672,4 @@ def test26_saveload_coll(self, verb=False):
639672
msg = st2.__eq__(self.st, returnas=str)
640673
if msg is not True:
641674
raise Exception(msg)
642-
os.remove(pfe)
675+
os.remove(pfe)

0 commit comments

Comments
 (0)