Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions derivative/dglobal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
from scipy import interpolate, sparse
from scipy.special import legendre
from sklearn.linear_model import Lasso
from specderiv import cheb_deriv, fourier_deriv


@register("spectral")
class Spectral(Derivative):
def __init__(self, **kwargs):
def __init__(self, order=1, axis=0, basis='chebyshev', **kwargs):
"""
        Compute the numerical derivative by a spectral method: expand the data in a global basis
        (Chebyshev polynomials or the Fourier basis) and differentiate the expansion analytically.

Args:
order (int): order of the derivative, defaults to 1st order
axis (int): the dimension of the data along which to differentiate, defaults to first dimension
basis (str): "chebyshev" or "fourier", the set of basis functions to use for differentiation
**kwargs: Optional keyword arguments.

Keyword Args:
Expand All @@ -28,19 +32,22 @@ def __init__(self, **kwargs):
"""
# Filter function. Default: Identity filter
self.filter = kwargs.get('filter', np.vectorize(lambda f: 1))
self._x_hat = None
self._freq = None
self.order = order
if basis not in ['chebyshev', 'fourier']:
raise ValueError("Only chebyshev and fourier bases are allowed.")
self.basis = basis

def _dglobal(self, t, x):
self._x_hat = np.fft.fft(x)
self._freq = np.fft.fftfreq(t.size, d=(t[1] - t[0]))
def _global(self, t, x):
if self.basis == 'chebyshev':
return cheb_deriv(x, t, self.order, self.axis)
else: # self.basis == 'fourier'
return fourier_deriv(x, t, self.order, self.axis)

def compute(self, t, x, i):
return next(self.compute_for(t, x, [i]))

def compute_for(self, t, x, indices):
self._dglobal(t, x)
res = np.fft.ifft(1j * 2 * np.pi * self._freq * self.filter(self._freq) * self._x_hat).real
res = self._global(t, x) # cached
for i in indices:
yield res[i]

Expand Down Expand Up @@ -212,7 +219,6 @@ def __init__(self, alpha=None):
"""
self.alpha = alpha


@_memoize_arrays(1)
def _global(self, t, z, alpha):
if alpha is None:
Expand Down
2 changes: 1 addition & 1 deletion derivative/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def compute(self, t, x, i):
"""
Compute the derivative of one-dimensional data x with respect to t at the index i of x, (dx/dt)[i].

Computation of a derivative should fail explicitely if the implementation is unable to compute a derivative at
Computation of a derivative should fail explicitly if the implementation is unable to compute a derivative at
the desired index. Used for global differentiation methods, for example.

This requires that x and t have equal lengths >= 2, and that the index i is a valid index.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ numpy = ">=1.18.3"
scipy = "^1.4.1"
scikit-learn = "^1"
importlib-metadata = ">=7.1.0"
spectral-derivatives = ">=0.6"

# docs
sphinx = {version = "7.2.6", optional = true}
Expand Down
29 changes: 15 additions & 14 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from derivative.differentiation import _gen_method


# Utilities for tests
# ===================
def default_args(kind):
""" The assumption is that the function will have dt = 1/100 over a range of 1 and not vary much. The goal is to
    set the parameters such that we obtain effective derivatives under these conditions.
Expand All @@ -26,8 +28,7 @@ def default_args(kind):
return {"sigma": 1, "lmbd": .01, "kernel": "gaussian"}
else:
raise ValueError('Unimplemented default args for kind {}.'.format(kind))



class NumericalExperiment:
def __init__(self, fn, fn_str, t, kind, args):
self.fn = fn
Expand All @@ -40,7 +41,6 @@ def __init__(self, fn, fn_str, t, kind, args):
def run(self):
return dxdt(self.fn(self.t), self.t, self.kind, self.axis, **self.kwargs)


def compare(experiment, truth, rel_tol, abs_tol, shape_only=False):
""" Compare a numerical experiment to theoretical expectations. Issue warnings for derivative methods that fail,
use asserts for implementation requirements.
Expand All @@ -60,8 +60,8 @@ def mean_sq(x):
assert np.linalg.norm(residual, ord=np.inf) < max(abs_tol, np.linalg.norm(truth, ord=np.inf) * rel_tol)


# Check that numbers are returned
# ===============================
# Check that only numbers are returned
# ====================================
@pytest.mark.parametrize("m", methods)
def test_notnan(m):
t = np.linspace(0, 1, 100)
Expand All @@ -71,8 +71,8 @@ def test_notnan(m):
assert not np.any(np.isnan(values)), message


# Test some basic functions
# =========================
# Test that basic functions are differentiated correctly
# ======================================================
funcs_and_derivs = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good library of examples, well thought out.

(lambda t: np.ones_like(t), "f(t) = 1", lambda t: np.zeros_like(t), "const1"),
(lambda t: np.zeros_like(t), "f(t) = 0", lambda t: np.zeros_like(t), "const0"),
Expand Down Expand Up @@ -112,6 +112,8 @@ def test_fn(m, func_spec):
compare(nexp, deriv(t), 1e-1, 1e-1, bad_combo)


# Test smoothing for those that do it
# ===================================
@pytest.mark.parametrize("kind", ("kalman", "trend_filtered"))
def test_smoothing_x(kind):
t = np.linspace(0, 1, 100)
Expand All @@ -122,7 +124,6 @@ def test_smoothing_x(kind):
# MSE
assert np.linalg.norm(x_est - np.sin(t)) ** 2 / len(t) < 1e-1


@pytest.mark.parametrize("kind", ("kalman", "trend_filtered"))
def test_smoothing_functional(kind):
t = np.linspace(0, 1, 100)
Expand All @@ -133,13 +134,14 @@ def test_smoothing_functional(kind):
assert np.linalg.norm(x_est - np.sin(t)) ** 2 / len(t) < 1e-1


# Test caching of the expensive _gen_method using a dummy
# =======================================================
@pytest.fixture
def clean_gen_method_cache():
_gen_method.cache_clear()
yield
_gen_method.cache_clear()


def test_gen_method_caching(clean_gen_method_cache):
x = np.ones(3)
t = np.arange(3)
Expand All @@ -150,7 +152,6 @@ def test_gen_method_caching(clean_gen_method_cache):
assert _gen_method.cache_info().currsize == 1
assert id(expected) == id(result)


def test_gen_method_kwarg_caching(clean_gen_method_cache):
x = np.ones(3)
t = np.arange(3)
Expand All @@ -164,6 +165,8 @@ def test_gen_method_kwarg_caching(clean_gen_method_cache):
assert id(expected) != id(result)


# Test caching of the expensive private _global methods using a dummy
# ===================================================================
@pytest.fixture
def method_inst(request):
x = np.ones(3)
Expand All @@ -173,8 +176,7 @@ def method_inst(request):
yield x, t, method
method._global.cache_clear()


@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered"], indirect=True)
@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered", "spectral"], indirect=True)
def test_dglobal_caching(method_inst):
# make sure we're not recomputing expensive _global() method
x, t, method = method_inst
Expand All @@ -184,8 +186,7 @@ def test_dglobal_caching(method_inst):
assert method._global.cache_info().misses == 1
assert method._global.cache_info().currsize == 1


@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered"], indirect=True)
@pytest.mark.parametrize("method_inst", ["kalman", "trend_filtered", "spectral"], indirect=True)
def test_cached_global_order(method_inst):
x, t, method = method_inst
x = np.vstack((x, -x))
Expand Down