Skip to content

Commit 1ccdfce

Browse files
committed
Follow NumPy rules for scalar-array operations
1 parent 1f0f33f commit 1ccdfce

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

examples/ndarray/jit-numpy-funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Create some sample data
1919
a = blosc2.linspace(0, 1, 10 * 100, dtype="float32", shape=(10, 100))
2020
b = blosc2.linspace(1, 2, 10 * 100, dtype="float32", shape=(10, 100))
21-
c = blosc2.linspace(-10, 10, 10, dtype="float32", shape=(100,))
21+
c = blosc2.linspace(-10, 10, 100, dtype="float32", shape=(100,))
2222

2323

2424
# Example 1: Basic usage of the jit decorator with reduction

src/blosc2/lazyexpr.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,9 +1903,18 @@ def infer_dtype(op, value1, value2):
19031903
if op != "~" and value2.dtype != np.bool_:
19041904
raise ValueError(f"Invalid operand type for {op}: {value2.dtype}")
19051905
return np.dtype(np.bool_)
1906-
dtype1 = value1.dtype if hasattr(value1, "dtype") else np.array(value1).dtype
1907-
dtype2 = value2.dtype if hasattr(value2, "dtype") else np.array(value2).dtype
1908-
return np.result_type(dtype1, dtype2)
1906+
1907+
# Follow NumPy rules for scalar-array operations
1908+
# Create small arrays with the same dtypes and let NumPy's type promotion determine the result type
1909+
if np.isscalar(value1) and hasattr(value2, "shape"):
1910+
arr2 = np.array([0], dtype=value2.dtype)
1911+
return (value1 + arr2).dtype
1912+
elif np.isscalar(value2) and hasattr(value1, "shape"):
1913+
arr1 = np.array([0], dtype=value1.dtype)
1914+
return (arr1 + value2).dtype
1915+
else:
1916+
# Both are arrays or both are scalars, use NumPy's type promotion rules
1917+
return np.result_type(value1, value2)
19091918

19101919

19111920
class LazyExpr(LazyArray):

tests/ndarray/test_lazyexpr.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,3 +1402,24 @@ def test_chain_persistentexpressions():
14021402
le4_.save("expr4.b2nd", mode="w")
14031403
myle4 = blosc2.open("expr4.b2nd")
14041404
assert (myle4[:] == le4[:]).all()
1405+
1406+
1407+
@pytest.mark.parametrize(
1408+
"values",
1409+
[
1410+
(np.ones(10, dtype=np.uint16), 2),
1411+
(np.ones(10, dtype=np.uint16), np.uint32(2)),
1412+
(2, np.ones(10, dtype=np.uint16)),
1413+
(np.uint32(2), np.ones(10, dtype=np.uint16)),
1414+
(np.ones(10, dtype=np.uint16), 2.0),
1415+
(np.ones(10, dtype=np.float32), 2.0),
1416+
(np.ones(10, dtype=np.float32), 2.0j),
1417+
],
1418+
)
1419+
def test_scalar_dtypes(values):
1420+
value1, value2 = values
1421+
dtype1 = (value1 + value2).dtype
1422+
avalue1 = blosc2.asarray(value1) if hasattr(value1, "shape") else value1
1423+
avalue2 = blosc2.asarray(value2) if hasattr(value2, "shape") else value2
1424+
dtype2 = (avalue1 * avalue2).dtype
1425+
assert dtype1 == dtype2, f"Expected {dtype1} but got {dtype2}"

0 commit comments

Comments
 (0)