Skip to content

Commit e7262ad

Browse files
committed
Solve issue #503
1 parent dbb8532 commit e7262ad

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

src/blosc2/lazyexpr.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,8 +2713,9 @@ def find_args(expr):
27132713

27142714
return value, expression[idx:idx2]
27152715

2716-
def _compute_expr(self, item, kwargs): # noqa : C901
27172716
# ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
2717+
2718+
def _compute_expr(self, item, kwargs): # noqa : C901
27182719
# that are implemenetd in python-blosc2 not in numexpr
27192720
global safe_blosc2_globals
27202721
if len(safe_blosc2_globals) == 0:
@@ -2748,10 +2749,11 @@ def _compute_expr(self, item, kwargs): # noqa : C901
27482749
where_x = self._where_args["_where_x"]
27492750
where_y = self._where_args["_where_y"]
27502751
return np.where(lazy_expr, where_x, where_y)[key]
2751-
if hasattr(self, "_output"):
2752+
out = kwargs.get("_output", None)
2753+
if out is not None:
27522754
# This is not exactly optimized, but it works for now
2753-
self._output[:] = lazy_expr[key]
2754-
return self._output
2755+
out[:] = lazy_expr[key]
2756+
return out
27552757
arr = lazy_expr[key]
27562758
if builtins.sum(mask) > 0:
27572759
# Correct shape to adjust to NumPy convention
@@ -2820,11 +2822,11 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
28202822

28212823
def compute(self, item=(), **kwargs) -> blosc2.NDArray:
28222824
# When NumPy ufuncs are called, the user may add an `out` parameter to kwargs
2823-
if "out" in kwargs:
2825+
if "out" in kwargs: # use provided out preferentially
28242826
kwargs["_output"] = kwargs.pop("out")
2825-
self._output = kwargs["_output"]
2826-
if hasattr(self, "_output"):
2827+
elif hasattr(self, "_output"):
28272828
kwargs["_output"] = self._output
2829+
28282830
if "ne_args" in kwargs:
28292831
kwargs["_ne_args"] = kwargs.pop("ne_args")
28302832
if hasattr(self, "_ne_args"):

tests/ndarray/test_lazyexpr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,8 +1731,18 @@ def test_lazylinalg():
17311731

17321732
# Test for issue #503 (LazyArray.compute() should honor out param)
17331733
def test_lazyexpr_compute_out():
1734+
# check reductions
17341735
a = blosc2.ones(10)
17351736
out = blosc2.zeros(1)
17361737
lexpr = blosc2.lazyexpr("sum(a)")
17371738
assert lexpr.compute(out=out) is out
17381739
assert out[0] == 10
1740+
assert lexpr.compute() is not out
1741+
1742+
# check normal expression
1743+
a = blosc2.ones(10)
1744+
out = blosc2.zeros(10)
1745+
lexpr = blosc2.lazyexpr("sin(a)")
1746+
assert lexpr.compute(out=out) is out
1747+
assert out[0] == np.sin(1)
1748+
assert lexpr.compute() is not out

0 commit comments

Comments
 (0)