Skip to content

Commit e6c4753

Browse files
committed
feat: improve jit a bit and add perf test
1 parent 95d74c3 commit e6c4753

File tree

3 files changed

+1023
-16
lines changed

3 files changed

+1023
-16
lines changed

PyEMD/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
__version__ = "1.8.0"
3+
__version__ = "1.8.1"
44
logger = logging.getLogger("pyemd")
55

66
from PyEMD.CEEMDAN import CEEMDAN # noqa

PyEMD/experimental/jitemd.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def __call__(self, s, t, max_imf=-1):
8383

8484
@nb.jit(float64[:](float64[:], int64, float64[:]), nopython=True)
8585
def np_round(x, decimals, out):
86-
return np.round_(x, decimals, out)
86+
out[:] = np.round(x, decimals)
87+
return out
8788

8889

8990
@nb.njit
@@ -216,9 +217,11 @@ def _find_extrema_simple(T: np.ndarray, S: np.ndarray) -> FindExtremaOutput:
216217
indmin = np.append(indmin, np.array(imin)).astype(np.int64)
217218
indmin.sort()
218219

219-
local_max_pos = T[indmax].astype(S.dtype)
220+
# Return indices as float64 (for _prepare_points_simple which expects indices)
221+
# and the values at those indices
222+
local_max_pos = indmax.astype(S.dtype) # indices, not T values
220223
local_max_val = S[indmax].astype(S.dtype)
221-
local_min_pos = T[indmin].astype(S.dtype)
224+
local_min_pos = indmin.astype(S.dtype) # indices, not T values
222225
local_min_val = S[indmin].astype(S.dtype)
223226

224227
return local_max_pos, local_max_val, local_min_pos, local_min_val, indzer.astype(S.dtype)
@@ -594,18 +597,19 @@ def _prepare_points_simple(
594597
max_extrema = np.vstack((tmax, zmax))
595598
min_extrema = np.vstack((tmin, zmin))
596599

597-
# For posterity:
598-
# I tried with np.delete and np.vstack([ ]) but both didn't work.
599-
# np.delete works only with 2 args, and vstack had problem with list comphr.
600-
max_dup_idx = np.where(max_extrema[0, 1:] == max_extrema[0, :-1])[0]
601-
if len(max_dup_idx):
602-
for col_idx in max_dup_idx:
603-
max_extrema = np.hstack((max_extrema[:, :col_idx], max_extrema[:, col_idx + 1 :]))
604-
605-
min_dup_idx = np.where(min_extrema[0, 1:] == min_extrema[0, :-1])[0]
606-
if len(min_dup_idx):
607-
for col_idx in min_dup_idx:
608-
min_extrema = np.hstack((min_extrema[:, :col_idx], min_extrema[:, col_idx + 1 :]))
600+
# Remove duplicates - keep only unique x positions
601+
# Use a mask-based approach to handle index shifting correctly
602+
max_unique_mask = np.ones(max_extrema.shape[1], dtype=np.bool_)
603+
for i in range(1, max_extrema.shape[1]):
604+
if max_extrema[0, i] == max_extrema[0, i - 1]:
605+
max_unique_mask[i] = False
606+
max_extrema = max_extrema[:, max_unique_mask]
607+
608+
min_unique_mask = np.ones(min_extrema.shape[1], dtype=np.bool_)
609+
for i in range(1, min_extrema.shape[1]):
610+
if min_extrema[0, i] == min_extrema[0, i - 1]:
611+
min_unique_mask[i] = False
612+
min_extrema = min_extrema[:, min_unique_mask]
609613

610614
return max_extrema, min_extrema
611615

0 commit comments

Comments
 (0)