diff --git a/datastock/_generic_check.py b/datastock/_generic_check.py index ee31ee6..7e4716c 100644 --- a/datastock/_generic_check.py +++ b/datastock/_generic_check.py @@ -556,7 +556,20 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None): # ############################################################################# -def _check_all_broadcastable(**kwdargs): +def _check_all_broadcastable( + return_full_arrays=None, + **kwdargs, +): + + # ------------------- + # return_full_arrays + # ------------------- + + return_full_arrays = _check_var( + return_full_arrays, 'return_full_arrays', + types=bool, + default=False, + ) # ------------------- # Preliminary check @@ -567,7 +580,7 @@ def _check_all_broadcastable(**kwdargs): for k0, v0 in kwdargs.items(): try: dout[k0] = np.atleast_1d(v0) - except Exception as err: + except Exception: dfail[k0] = f"Not convertible to np.ndarray! - {v0}" # Raise Exception @@ -588,7 +601,10 @@ def _check_all_broadcastable(**kwdargs): if len(lndim) == 0: # all scalar - return {k0: v0[0] for k0, v0 in dout.items()}, None + if return_full_arrays: + return dout, (1,) + else: + return {k0: v0[0] for k0, v0 in dout.items()}, None elif len(lndim) == 1: ndim = lndim[0] @@ -610,7 +626,6 @@ def _check_all_broadcastable(**kwdargs): for k0, v0 in dout.items(): if v0.shape == (1,): - dout[k0] = v0[0] continue for ii in range(ndim): @@ -636,6 +651,22 @@ def _check_all_broadcastable(**kwdargs): ) raise Exception(msg) + # ------------------- + # reshape output + # ------------------- + + if return_full_arrays is True: + for k0, v0 in dout.items(): + if v0.shape == (1,): + dout[k0] = np.full(shapef, v0[0]) + elif v0.shape != shapef: + dout[k0] = np.broadcast_to(v0, shapef) + + else: + for k0, v0 in dout.items(): + if v0.shape == (1,): + dout[k0] = v0[0] + return dout, shapef diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 3e39bd0..93a969e 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -258,9 +258,17 @@ def test04_check_all_broadcastable(self): dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3)) except Exception: err = True - assert err is True + # all arrays - mix + dout, shape = _check_all_broadcastable( + a=np.r_[1, 2, 3][:, None], + b=np.r_[10, 20][None, :], + c=3, + return_full_arrays=True, + ) + assert all([v0.shape == (3, 2) for v0 in dout.values()]) + ####################################################### #