Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions datastock/_generic_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,20 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None):
# #############################################################################


def _check_all_broadcastable(**kwdargs):
def _check_all_broadcastable(
return_full_arrays=None,
**kwdargs,
):

# -------------------
# return_full_arrays
# -------------------

return_full_arrays = _check_var(
return_full_arrays, 'return_full_arrays',
types=bool,
default=False,
)

# -------------------
# Preliminary check
Expand All @@ -567,7 +580,7 @@ def _check_all_broadcastable(**kwdargs):
for k0, v0 in kwdargs.items():
try:
dout[k0] = np.atleast_1d(v0)
except Exception as err:
except Exception:
dfail[k0] = f"Not convertible to np.ndarray! - {v0}"

# Raise Exception
Expand All @@ -588,7 +601,10 @@ def _check_all_broadcastable(**kwdargs):

if len(lndim) == 0:
# all scalar
return {k0: v0[0] for k0, v0 in dout.items()}, None
if return_full_arrays:
return dout, (1,)
else:
return {k0: v0[0] for k0, v0 in dout.items()}, None

elif len(lndim) == 1:
ndim = lndim[0]
Expand All @@ -610,7 +626,6 @@ def _check_all_broadcastable(**kwdargs):
for k0, v0 in dout.items():

if v0.shape == (1,):
dout[k0] = v0[0]
continue

for ii in range(ndim):
Expand All @@ -636,6 +651,22 @@ def _check_all_broadcastable(**kwdargs):
)
raise Exception(msg)

# -------------------
# reshape output
# -------------------

if return_full_arrays is True:
for k0, v0 in dout.items():
if v0.shape == (1,):
dout[k0] = np.full(shapef, v0[0])
elif v0.shape != shapef:
dout[k0] = np.broadcast_to(v0, shapef)

else:
for k0, v0 in dout.items():
if v0.shape == (1,):
dout[k0] = v0[0]

return dout, shapef


Expand Down
10 changes: 9 additions & 1 deletion datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,17 @@ def test04_check_all_broadcastable(self):
dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3))
except Exception:
err = True

assert err is True

# all arrays - mix
dout, shape = _check_all_broadcastable(
a=np.r_[1, 2, 3][:, None],
b=np.r_[10, 20][None, :],
c=3,
return_full_arrays=True,
)
assert all([v0.shape == (3, 2) for v0 in dout.values()])


#######################################################
#
Expand Down
Loading