Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions nitime/_compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
"""Compatibility utilities for different dependency versions."""

import numpy as np
from packaging.version import Version


def _reshape_view(arr, shape):
"""Reshape an array as a view, raising if a copy would be required.

This function provides compatibility across NumPy versions for reshaping
arrays as views. On NumPy >= 2.1, it uses ``reshape(copy=False)`` which
explicitly fails if a view cannot be created. On older versions, it uses
direct shape assignment which has the same behavior but is deprecated in
NumPy 2.5+.

Parameters
----------
arr : ndarray
The array to reshape.
shape : tuple of int
The new shape.

Returns
-------
ndarray
A reshaped view of the array.

Raises
------
AttributeError
If a view cannot be created on NumPy < 2.1.
ValueError
If a view cannot be created on NumPy >= 2.1.
"""
if Version(np.__version__) >= Version("2.1"):
return arr.reshape(shape, copy=False)
else:
arr.shape = shape
return arr


# np.trapezoid was introduced and np.trapz deprecated in numpy 2.0
try: # NP2
if Version(np.__version__) >= Version("2.0"):
from numpy import trapezoid
except ImportError: # NP1
else:
from numpy import trapz as trapezoid
16 changes: 9 additions & 7 deletions nitime/algorithms/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
# To support older versions of numpy that don't have tril_indices:
from nitime.index_utils import tril_indices, triu_indices

from nitime._compat import _reshape_view


# Set global variables for the default NFFT to be used in spectral analysis and
# the overlap:
Expand Down Expand Up @@ -308,7 +310,7 @@ def periodogram_csd(s, Fs=2 * np.pi, Sk=None, NFFT=None, sides='default',

"""
s_shape = s.shape
s.shape = (-1, s_shape[-1])
s = _reshape_view(s, (-1, s_shape[-1]))
# defining an Sk_loc is a little opaque, but it avoids having to
# reset the shape of any user-given Sk later on
if Sk is not None:
Expand All @@ -322,7 +324,7 @@ def periodogram_csd(s, Fs=2 * np.pi, Sk=None, NFFT=None, sides='default',
N = s.shape[-1]
Sk_loc = fftpack.fft(s, n=N)
# reset s.shape
s.shape = s_shape
s = _reshape_view(s, s_shape)

M = Sk_loc.shape[0]

Expand Down Expand Up @@ -550,7 +552,7 @@ def multi_taper_psd(
NFFT = spectra.shape[-1]
K = len(eigvals)
# collapse spectra's shape back down to 3 dimensions
spectra.shape = (M, K, NFFT)
spectra = _reshape_view(spectra, (M, K, NFFT))

last_freq = NFFT // 2 + 1 if sides == 'onesided' else NFFT

Expand Down Expand Up @@ -593,12 +595,12 @@ def multi_taper_psd(
freqs = np.linspace(0, Fs, NFFT, endpoint=False)

out_shape = s.shape[:-1] + (len(freqs),)
sdf_est.shape = out_shape
sdf_est = _reshape_view(sdf_est, out_shape)
if jackknife:
jk_var.shape = out_shape
jk_var = _reshape_view(jk_var, out_shape)
return freqs, sdf_est, jk_var
else:
nu.shape = out_shape
nu = _reshape_view(nu, out_shape)
return freqs, sdf_est, nu


Expand Down Expand Up @@ -690,7 +692,7 @@ def multi_taper_csd(s, Fs=2 * np.pi, NW=None, BW=None, low_bias=True,
NFFT = spectra.shape[-1]
K = len(eigvals)
# collapse spectra's shape back down to 3 dimensions
spectra.shape = (M, K, NFFT)
spectra = _reshape_view(spectra, (M, K, NFFT))

# compute the cross-spectral density functions
last_freq = NFFT // 2 + 1 if sides == 'onesided' else NFFT
Expand Down
2 changes: 1 addition & 1 deletion nitime/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def test_timearray_math_functions(f, tu):
"Calling TimeArray.min() .max(), mean() should return TimeArrays"
a = np.arange(2, 11)
b = ts.TimeArray(a, time_unit=tu)
if f == "ptp" and ts._NP_2:
if f == "ptp":
want = np.ptp(a)
else:
want = getattr(a, f)()
Expand Down
10 changes: 1 addition & 9 deletions nitime/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@
# Our own
from nitime import descriptors as desc

try:
_NP_2 = int(np.__version__.split(".")[0]) >= 2
except Exception:
_NP_2 = True

#-----------------------------------------------------------------------------
# Module globals
Expand Down Expand Up @@ -319,11 +315,7 @@ def mean(self, *args, **kwargs):
return ret

def ptp(self, *args, **kwargs):
if _NP_2:
ptp = np.ptp
else:
ptp = np.ndarray.ptp
ret = TimeArray(ptp(self, *args, **kwargs),
ret = TimeArray(np.ptp(self, *args, **kwargs),
time_unit=base_unit)
ret.convert_unit(self.time_unit)
return ret
Expand Down
6 changes: 4 additions & 2 deletions nitime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from nitime.lazy import scipy_stats_distributions as dists
from nitime.lazy import scipy_interpolate as interpolate

from nitime._compat import _reshape_view


#-----------------------------------------------------------------------------
# Spectral estimation testing utilities
Expand Down Expand Up @@ -756,7 +758,7 @@ def tapered_spectra(s, tapers, NFFT=None, low_bias=True):
# compute the y_{i,k}(f) -- full FFT takes ~1.5x longer, but unpacking
# results of real-valued FFT eats up memory
t_spectra = fftpack.fft(tapered, n=NFFT, axis=-1)
t_spectra.shape = rest_of_dims + (K, NFFT)
t_spectra = _reshape_view(t_spectra, rest_of_dims + (K, NFFT))
if eigvals is None:
return t_spectra
return t_spectra, eigvals
Expand Down Expand Up @@ -834,7 +836,7 @@ def detect_lines(s, tapers, p=None, **taper_kws):
numr[...,0] = 1; # don't care about DC
# denominator -- strength of residual
spectra = np.rollaxis(spectra, -2, 0)
U0.shape = (K,) + (1,) * (spectra.ndim-1)
U0 = _reshape_view(U0, (K,) + (1,) * (spectra.ndim-1))
denomr = spectra - U0*mu
denomr = np.sum(np.abs(denomr)**2, axis=0) / (2*K-2)
denomr[...,0] = 1;
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
dependencies = [
"matplotlib>=3.7",
"numpy>=1.24",
"packaging",
"scipy>=1.10",
]

Expand Down
Loading