Skip to content

Commit 4b3e07c

Browse files
Merge pull request #408 from Blosc/fixLinspace
Fix linspace for incompatible num/shape
2 parents 383d254 + 6b0bd8f commit 4b3e07c

File tree

4 files changed

+24
-5
lines changed

4 files changed

+24
-5
lines changed

src/blosc2/ndarray.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3266,7 +3266,9 @@ def arange_fill(inputs, output, offset):
32663266

32673267

32683268
# Define a numpy linspace-like function
3269-
def linspace(start, stop, num=50, endpoint=True, dtype=np.float64, shape=None, c_order=True, **kwargs: Any):
3269+
def linspace(
3270+
start, stop, num=None, endpoint=True, dtype=np.float64, shape=None, c_order=True, **kwargs: Any
3271+
):
32703272
"""Return evenly spaced numbers over a specified interval.
32713273
32723274
This is similar to `numpy.linspace` but it returns a `NDArray`
@@ -3308,8 +3310,17 @@ def linspace_fill(inputs, output, offset):
33083310
stop_ = start_ + lout / num * (stop - start)
33093311
output[:] = np.linspace(start_, stop_, lout, endpoint=False, dtype=output.dtype)
33103312

3311-
if not shape:
3312-
shape = (num,)
3313+
if shape is None or num is None:
3314+
if shape is None and num is None:
3315+
raise ValueError("Either `shape` or `num` must be specified.")
3316+
if shape is None: # num is not None
3317+
shape = (num,)
3318+
else: # num is none
3319+
num = math.prod(shape)
3320+
3321+
# check compatibility of shape and num
3322+
if math.prod(shape) != num:
3323+
raise ValueError("The specified shape is not consistent with the specified num value")
33133324
dtype = _check_dtype(dtype)
33143325

33153326
if is_inside_new_expr():

tests/ndarray/test_evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def sample_data(request):
2727
# The jit decorator can work with any numpy or NDArray params in functions
2828
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
2929
b = np.linspace(1, 2, shape[0] * shape[1], dtype=dtype).reshape(shape)
30-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
30+
c = blosc2.linspace(-10, 10, np.prod(cshape), dtype=dtype, shape=cshape)
3131
return a, b, c, shape
3232

3333

tests/ndarray/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def sample_data(request):
2626
# The jit decorator can work with any numpy or NDArray params in functions
2727
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
2828
b = np.linspace(1, 2, shape[0] * shape[1], dtype=dtype).reshape(shape)
29-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
29+
c = blosc2.linspace(-10, 10, np.prod(cshape), dtype=dtype, shape=cshape)
3030
return a, b, c, shape
3131

3232

tests/ndarray/test_ndarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def test_linspace(ss, shape, dtype, chunks, blocks, endpoint, c_order):
187187
else:
188188
# This is chunk order, so testing is more laborious, and not really necessary
189189
pass
190+
with pytest.raises(ValueError):
191+
a = blosc2.linspace(start, stop, 10, shape=(20,)) # num incompatible with shape
192+
with pytest.raises(ValueError):
193+
a = blosc2.linspace(start, stop) # num or shape should be specified
194+
a = blosc2.linspace(start, stop, shape=(20,)) # should have length 20
195+
assert a.shape == (20,)
196+
a = blosc2.linspace(start, stop, num=20) # should have length 20
197+
assert a.shape == (20,)
190198

191199

192200
@pytest.mark.parametrize(("N", "M"), [(10, None), (10, 20), (20, 10)])

0 commit comments

Comments
 (0)