Skip to content

Commit 5913fba

Browse files
Making the blosc2 engine respect the default pandas udf signatures
1 parent b334981 commit 5913fba

File tree

2 files changed

+59
-50
lines changed

2 files changed

+59
-50
lines changed

src/blosc2/proxy.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,7 @@ def map(cls, data, func, args, kwargs, decorator, skip_na):
699699
with the NumPy array as the function parameter, instead of calling the
700700
function once for each element.
701701
"""
702-
data = cls._ensure_numpy_data(data)
703-
func = decorator(func)
704-
return func(data, *args, **kwargs)
702+
raise NotImplementedError("The Blosc2 engine does not support map. Use apply instead.")
705703

706704
@classmethod
707705
def apply(cls, data, func, args, kwargs, decorator, axis):
@@ -713,7 +711,23 @@ def apply(cls, data, func, args, kwargs, decorator, axis):
713711
"""
714712
data = cls._ensure_numpy_data(data)
715713
func = decorator(func)
716-
return func(data, *args, **kwargs)
714+
if data.ndim == 1 or axis is None:
715+
# pandas Series.apply or pipe
716+
return func(data, *args, **kwargs)
717+
elif axis in (0, "index"):
718+
# pandas apply(axis=0) column-wise
719+
result = []
720+
for row_idx in range(data.shape[1]):
721+
result.append(func(data[:, row_idx], *args, **kwargs))
722+
return np.vstack(result).transpose()
723+
elif axis in (1, "columns"):
724+
# pandas apply(axis=1) row-wise
725+
result = []
726+
for col_idx in range(data.shape[0]):
727+
result.append(func(data[col_idx, :], *args, **kwargs))
728+
return np.vstack(result)
729+
else:
730+
raise NotImplementedError(f"Unknown axis '{axis}'. Use one of 0, 1 or None.")
717731

718732

719733
jit.__pandas_udf__ = PandasUdfEngine

tests/test_pandas_udf_engine.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,72 +7,88 @@
77
#######################################################################
88

99
import numpy as np
10+
import pytest
1011

1112
import blosc2
1213

1314

1415
class TestPandasUDF:
15-
def test_map_1d(self):
16+
def test_map(self):
1617
def add_one(x):
1718
return x + 1
1819

1920
data = np.array([1, 2])
2021

21-
result = blosc2.jit.__pandas_udf__.map(
22+
with pytest.raises(NotImplementedError):
23+
blosc2.jit.__pandas_udf__.map(
24+
data,
25+
add_one,
26+
args=(),
27+
kwargs={},
28+
decorator=blosc2.jit,
29+
skip_na=False,
30+
)
31+
32+
def test_apply_1d(self):
33+
def add_one(x):
34+
return x + 1
35+
36+
data = np.array([1, 2])
37+
38+
result = blosc2.jit.__pandas_udf__.apply(
2239
data,
2340
add_one,
2441
args=(),
2542
kwargs={},
2643
decorator=blosc2.jit,
27-
skip_na=False,
44+
axis=0,
2845
)
2946
assert result.shape == (2,)
3047
assert result[0] == 2
3148
assert result[1] == 3
3249

33-
def test_map_1d_with_args(self):
50+
def test_apply_1d_with_args(self):
3451
def add_numbers(x, num1, num2):
3552
return x + num1 + num2
3653

3754
data = np.array([1, 2])
3855

39-
result = blosc2.jit.__pandas_udf__.map(
56+
result = blosc2.jit.__pandas_udf__.apply(
4057
data,
4158
add_numbers,
4259
args=(10,),
4360
kwargs={"num2": 100},
4461
decorator=blosc2.jit,
45-
skip_na=False,
62+
axis=0,
4663
)
4764
assert result.shape == (2,)
4865
assert result[0] == 111
4966
assert result[1] == 112
5067

51-
def test_map_2d(self):
68+
def test_apply_2d(self):
5269
def add_one(x):
70+
assert x.shape == (2, 3)
5371
return x + 1
5472

55-
data = np.array([[1, 2], [3, 4]])
73+
data = np.array([[1, 2, 3], [4, 5, 6]])
5674

57-
result = blosc2.jit.__pandas_udf__.map(
75+
result = blosc2.jit.__pandas_udf__.apply(
5876
data,
5977
add_one,
6078
args=(),
6179
kwargs={},
6280
decorator=blosc2.jit,
63-
skip_na=False,
81+
axis=None,
6482
)
65-
assert result.shape == (2, 2)
66-
assert result[0, 0] == 2
67-
assert result[0, 1] == 3
68-
assert result[1, 0] == 4
69-
assert result[1, 1] == 5
83+
expected = np.array([[2, 3, 4], [5, 6, 7]])
84+
assert np.array_equal(result, expected)
7085

71-
def test_apply_1d(self):
86+
def test_apply_2d_by_column(self):
7287
def add_one(x):
88+
assert x.shape == (2,)
7389
return x + 1
7490

75-
data = np.array([1, 2])
91+
data = np.array([[1, 2, 3], [4, 5, 6]])
7692

7793
result = blosc2.jit.__pandas_udf__.apply(
7894
data,
@@ -82,44 +98,23 @@ def add_one(x):
8298
decorator=blosc2.jit,
8399
axis=0,
84100
)
85-
assert result.shape == (2,)
86-
assert result[0] == 2
87-
assert result[1] == 3
101+
expected = np.array([[2, 3, 4], [5, 6, 7]])
102+
assert np.array_equal(result, expected)
88103

89-
def test_apply_1d_with_args(self):
90-
def add_numbers(x, num1, num2):
91-
return x + num1 + num2
92-
93-
data = np.array([1, 2])
94-
95-
result = blosc2.jit.__pandas_udf__.apply(
96-
data,
97-
add_numbers,
98-
args=(10,),
99-
kwargs={"num2": 100},
100-
decorator=blosc2.jit,
101-
axis=0,
102-
)
103-
assert result.shape == (2,)
104-
assert result[0] == 111
105-
assert result[1] == 112
106-
107-
def test_apply_2d(self):
104+
def test_apply_2d_by_row(self):
108105
def add_one(x):
106+
assert x.shape == (3,)
109107
return x + 1
110108

111-
data = np.array([[1, 2], [3, 4]])
109+
data = np.array([[1, 2, 3], [4, 5, 6]])
112110

113111
result = blosc2.jit.__pandas_udf__.apply(
114112
data,
115113
add_one,
116114
args=(),
117115
kwargs={},
118116
decorator=blosc2.jit,
119-
axis=0,
117+
axis=1,
120118
)
121-
assert result.shape == (2, 2)
122-
assert result[0, 0] == 2
123-
assert result[0, 1] == 3
124-
assert result[1, 0] == 4
125-
assert result[1, 1] == 5
119+
expected = np.array([[2, 3, 4], [5, 6, 7]])
120+
assert np.array_equal(result, expected)

0 commit comments

Comments
 (0)