Skip to content

Commit 9e973ec

Browse files
committed
Fix numerical stability of linspace, add __bool__ for LazyExpr
1 parent 325c3e2 commit 9e973ec

File tree

3 files changed

+88
-36
lines changed

3 files changed

+88
-36
lines changed

src/blosc2/__init__.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,22 @@ class Tuner(Enum):
199199
uint64,
200200
)
201201

202+
DEFAULT_COMPLEX = complex128
203+
"""
204+
Default complex floating dtype."""
205+
206+
DEFAULT_FLOAT = float64
207+
"""
208+
Default real floating dtype."""
209+
210+
DEFAULT_INT = int64
211+
"""
212+
Default integer dtype."""
213+
214+
DEFAULT_INDEX = int64
215+
"""
216+
Default indexing dtype."""
217+
202218

203219
class Info:
204220
def __init__(self, **kwargs):
@@ -218,10 +234,10 @@ def __array_namespace_info__() -> Info:
218234
},
219235
default_device=None,
220236
default_dtypes={
221-
"real floating": float64,
222-
"complex floating": complex128,
223-
"integral": int64,
224-
"indexing": int64,
237+
"real floating": DEFAULT_FLOAT,
238+
"complex floating": DEFAULT_COMPLEX,
239+
"integral": DEFAULT_INT,
240+
"indexing": DEFAULT_INDEX,
225241
},
226242
dtypes={
227243
"bool": bool_,
@@ -455,6 +471,11 @@ def __array_namespace_info__() -> Info:
455471
"MIN_HEADER_LENGTH",
456472
"VERSION_DATE",
457473
"VERSION_STRING",
474+
# Default dtypes
475+
"DEFAULT_COMPLEX",
476+
"DEFAULT_FLOAT",
477+
"DEFAULT_INDEX",
478+
"DEFAULT_INT",
458479
# Mathematical constants
459480
"e",
460481
"pi",

src/blosc2/lazyexpr.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,11 @@ def to_cframe(self) -> bytes:
489489
"""
490490
return self.compute().to_cframe()
491491

492+
def __bool__(self) -> bool:
493+
if math.prod(self.shape) != 1:
494+
raise ValueError(f"The truth value of a LazyArray of shape {self.shape} is ambiguous.")
495+
return bool(self[()])
496+
492497

493498
def convert_inputs(inputs):
494499
if not inputs or len(inputs) == 0:
@@ -2540,7 +2545,7 @@ def __ror__(self, value):
25402545
return self.update_expr(new_op=(value, "|", self))
25412546

25422547
def __invert__(self):
2543-
return self.update_expr(new_op=(self, "~", None))
2548+
return self.update_expr(new_op=(None, "~", self))
25442549

25452550
def __pow__(self, value):
25462551
return self.update_expr(new_op=(self, "**", value))

src/blosc2/ndarray.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3635,7 +3635,7 @@ def ones(shape: int | tuple | list, dtype: np.dtype | str = None, **kwargs: Any)
36353635
dtype('float64')
36363636
"""
36373637
if dtype is None:
3638-
dtype = blosc2.float64
3638+
dtype = blosc2.DEFAULT_FLOAT
36393639
return full(shape, 1, dtype, **kwargs)
36403640

36413641

@@ -3655,11 +3655,11 @@ def arange(
36553655
36563656
Parameters
36573657
----------
3658-
start: int, float, complex or np.number
3658+
start: int, float
36593659
The starting value of the sequence.
3660-
stop: int, float, complex or np.number
3660+
stop: int, float
36613661
The end value of the sequence.
3662-
step: int, float, complex or np.number
3662+
step: int, float or None
36633663
Spacing between values.
36643664
dtype: np.dtype or list str
36653665
The data type of the array elements in NumPy format. Default is
@@ -3717,9 +3717,9 @@ def arange_fill(inputs, output, offset):
37173717
raise ValueError("The shape is not consistent with the start, stop and step values")
37183718
if dtype is None:
37193719
dtype = (
3720-
blosc2.float64
3720+
blosc2.DEFAULT_FLOAT
37213721
if np.any([np.issubdtype(type(d), float) for d in (start, stop, step)])
3722-
else blosc2.int64
3722+
else blosc2.DEFAULT_INT
37233723
)
37243724
dtype = _check_dtype(dtype)
37253725

@@ -3742,7 +3742,14 @@ def arange_fill(inputs, output, offset):
37423742

37433743
# Define a numpy linspace-like function
37443744
def linspace(
3745-
start, stop, num=None, endpoint=True, dtype=np.float64, shape=None, c_order=True, **kwargs: Any
3745+
start: int | float | complex,
3746+
stop: int | float | complex,
3747+
num: int | None = None,
3748+
dtype=None,
3749+
endpoint: bool = True,
3750+
shape=None,
3751+
c_order: bool = True,
3752+
**kwargs: Any,
37463753
) -> NDArray:
37473754
"""Return evenly spaced numbers over a specified interval.
37483755
@@ -3752,24 +3759,28 @@ def linspace(
37523759
37533760
Parameters
37543761
----------
3755-
start: int, float, complex or np.number
3762+
start: int, float, complex
37563763
The starting value of the sequence.
3757-
stop: int, float, complex or np.number
3764+
stop: int, float, complex
37583765
The end value of the sequence.
3759-
num: int
3760-
Number of samples to generate.
3766+
num: int | None
3767+
Number of samples to generate. If None, it is computed from `shape`. Default None.
3768+
dtype: np.dtype or list str
3769+
The data type of the array elements in NumPy format. If None, inferred from
3770+
start and stop. Default is None.
37613771
endpoint: bool
37623772
If True, `stop` is the last sample. Otherwise, it is not included.
3763-
dtype: np.dtype or list str
3764-
The data type of the array elements in NumPy format. Default is `np.float64`.
37653773
shape: int, tuple or list
37663774
The shape of the final array. If None, the shape will be guessed from `num`.
37673775
c_order: bool
37683776
Whether to store the array in C order (row-major) or insertion order.
37693777
Insertion order means that values will be stored in the array
37703778
following the order of chunks in the array; this is more memory
37713779
efficient, as it does not require an intermediate copy of the array.
3772-
Default is C order.
3780+
Default is True.
3781+
**kwargs: Any
3782+
Keyword arguments accepted by the :func:`empty` constructor.
3783+
37733784
37743785
Returns
37753786
-------
@@ -3779,33 +3790,48 @@ def linspace(
37793790

37803791
def linspace_fill(inputs, output, offset):
37813792
lout = len(output)
3782-
start, stop, num = inputs
3793+
start, stop, num, endpoint = inputs
3794+
# when num == 1 (or endpoint is False) divide by num rather than num - 1, which avoids a zero divisor
3795+
step = (stop - start) / (num - 1) if endpoint and num > 1 else (stop - start) / num
37833796
# Compute proper start and stop values for the current chunk
3784-
start_ = start + offset[0] / num * (stop - start)
3785-
stop_ = start_ + lout / num * (stop - start)
3786-
output[:] = np.linspace(start_, stop_, lout, endpoint=False, dtype=output.dtype)
3787-
3788-
if shape is None or num is None:
3789-
if shape is None and num is None:
3790-
raise ValueError("Either `shape` or `num` must be specified.")
3791-
if shape is None: # num is not None
3792-
shape = (num,)
3793-
else: # num is none
3794-
num = math.prod(shape)
3797+
# except for 0th iter, have already included start_ in prev iter
3798+
start_ = start + offset[0] * step if offset[0] == 0 else start + (offset[0] + 1) * step
3799+
stop_ = start_ + lout * step
3800+
if offset[0] + lout == num: # reached end
3801+
output[:] = np.linspace(start_, stop, lout, endpoint=endpoint, dtype=output.dtype)
3802+
else: # always include start and stop
3803+
output[:] = np.linspace(start_, stop_, lout, endpoint=True, dtype=output.dtype)
3804+
3805+
if num < 0:
3806+
raise ValueError("num must be nonnegative.")
3807+
3808+
if shape is None and num is None:
3809+
raise ValueError("Either `shape` or `num` must be specified.")
3810+
if shape is None: # num is not None
3811+
shape = (num,)
3812+
else: # num is none
3813+
num = math.prod(shape)
37953814

37963815
# check compatibility of shape and num
37973816
if math.prod(shape) != num:
37983817
raise ValueError("The specified shape is not consistent with the specified num value")
3818+
3819+
if dtype is None:
3820+
dtype = (
3821+
blosc2.DEFAULT_COMPLEX
3822+
if np.any([np.issubdtype(type(d), complex) for d in (start, stop)])
3823+
else blosc2.DEFAULT_FLOAT
3824+
)
3825+
37993826
dtype = _check_dtype(dtype)
38003827

3801-
if is_inside_new_expr():
3828+
if is_inside_new_expr() or num == 0:
38023829
# We already have the dtype and shape, so return immediately
3803-
return blosc2.zeros(shape, dtype=dtype)
3830+
return blosc2.zeros(shape, dtype=dtype) # will return empty array for num == 0
38043831

38053832
lshape = (math.prod(shape),)
3806-
if endpoint:
3807-
stop += (stop - start) / (num - 1)
3808-
inputs = (start, stop, num)
3833+
3834+
inputs = (start, stop, num, endpoint)
38093835
lazyarr = blosc2.lazyudf(linspace_fill, inputs, dtype=dtype, shape=lshape)
38103836
if len(shape) == 1:
38113837
# C order is guaranteed, and no reshape is needed

0 commit comments

Comments
 (0)