Description
`SINDy._process_trajectories` ignores `x_dot` when it is provided. Instead of returning the supplied derivatives, it recomputes them from `x` and `t`.
In the following example, `x_dot` and `x_dot_processed` should be equal, but they are not.
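Below is a minimal sketch of the behavior this report expects. It is written as a standalone function, not as pysindy's actual code. If derivatives are supplied, they pass through untouched. Only when `x_dot` is `None` are they estimated from `x` and `t`. The name `process_trajectories`, the scalar time step, and the use of `np.gradient` as the differentiation step are hypothetical stand-ins, not the library's API.

```python
import numpy as np


def process_trajectories(x, t, x_dot=None):
    """Hypothetical stand-in for SINDy._process_trajectories.

    x: list of (n_samples, n_features) arrays, one per trajectory.
    t: scalar time step shared by all trajectories (an assumption).
    x_dot: optional list of derivative arrays matching x.
    """
    if x_dot is not None:
        # Supplied derivatives are returned unchanged
        return x, x_dot
    # Estimate derivatives only when none were supplied;
    # np.gradient is a placeholder for the model's differentiation method
    x_dot = [np.gradient(xi, t, axis=0, edge_order=2) for xi in x]
    return x, x_dot


# Usage: one trajectory of (sin t, cos t) with its exact derivative
dt = 0.01
tt = np.arange(0, 1, dt)
x = [np.column_stack([np.sin(tt), np.cos(tt)])]
exact = [np.column_stack([np.cos(tt), -np.sin(tt)])]

_, x_dot_out = process_trajectories(x, dt, exact)  # supplied: passed through
_, x_dot_est = process_trajectories(x, dt)  # not supplied: estimated
np.testing.assert_array_equal(exact[0], x_dot_out[0])
```

With derivatives supplied, the equality check passes. The estimated derivatives are only approximately equal to `exact`, which is the kind of mismatch shown in the error message below.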
### Reproducing code example:
```python
import pysindy as ps
import numpy as np
from scipy.integrate import solve_ivp
from pysindy.utils import lorenz
from pysindy.pysindy import (
    _check_multiple_trajectories,
    _adapt_to_multiple_trajectories,
    _comprehend_and_validate_inputs,
)

# Tight tolerances so the simulated trajectory is accurate
integrator_keywords = {}
integrator_keywords['rtol'] = 1e-12
integrator_keywords['method'] = 'LSODA'
integrator_keywords['atol'] = 1e-12

dt = 0.001
t_train = np.arange(0, 100, dt)
t_train_span = (t_train[0], t_train[-1])
x0_train = [-8, 8, 27]
x_train = solve_ivp(
    lorenz, t_train_span, x0_train, t_eval=t_train, **integrator_keywords
).y.T
# Exact derivatives, evaluated from the Lorenz right-hand side
x_dot_train = np.array(
    [lorenz(0, x_train[i]) for i in range(t_train.size)]
)
u = None

model = ps.SINDy(
    optimizer=ps.STLSQ(),
    feature_library=ps.PolynomialLibrary(),
)

# Convert single-trajectory inputs to the multiple-trajectory format
if not _check_multiple_trajectories(x_train, x_dot_train, u):
    x, t, x_dot, u = _adapt_to_multiple_trajectories(x_train, dt, x_dot_train, u)
x, x_dot, u = _comprehend_and_validate_inputs(
    x, t, x_dot, u, feature_library=ps.PolynomialLibrary()
)
x_processed, x_dot_processed = model._process_trajectories(x, t, x_dot)
# Fails: x_dot_processed holds recomputed derivatives, not the supplied x_dot
np.testing.assert_array_equal(x_dot, x_dot_processed)
```

### Error message:
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 12
      8 x, x_dot, u = _comprehend_and_validate_inputs(
      9     x, t, x_dot, u, feature_library=ps.PolynomialLibrary()
     10 )
     11 x_processed, x_dot_processed = model._process_trajectories(x, t, x_dot)
---> 12 np.testing.assert_array_equal(x_dot, x_dot_processed)

    [... skipping hidden 1 frame]

File ~/pysindy/env/lib/python3.12/site-packages/numpy/testing/_private/utils.py:926, in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict, names)
    921     err_msg += '\n' + '\n'.join(remarks)
    922     msg = build_err_msg([ox, oy], err_msg,
    923                         verbose=verbose, header=header,
    924                         names=names,
    925                         precision=precision)
--> 926     raise AssertionError(msg)
    927 except ValueError:
    928     import traceback

AssertionError:
Arrays are not equal

Mismatched elements: 300000 / 300000 (100%)
...
       [ 158.24432 , -16.883872, -134.237849],
       [ 156.49759 , -17.712639, -132.492892],...
 DESIRED: array([[[ 159.997055, -16.018512, -135.994382],
        [ 158.245812, -16.874688, -134.24073 ],
        [ 156.499157, -17.703738, -132.495873],...
```
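The first row labelled DESIRED is the second argument to `assert_array_equal`, so it is `x_dot_processed`. It can be checked against the Lorenz right-hand side evaluated at the initial condition `x0_train = [-8, 8, 27]`. The sketch below re-implements that right-hand side locally as a hypothetical helper, `lorenz_rhs`. It assumes that `pysindy.utils.lorenz` uses the defaults `sigma=10`, `beta=2.66667`, and `rho=28`. Under that assumption, the exact derivative at `x0_train` is `[160, -16, -136.00009]`. The processed row differs from it by at most about 0.0185, in the second component. This suggests the returned values are a numerical estimate rather than the supplied `x_dot`. A small difference like this at every sample would account for all 300000 elements mismatching.

```python
import numpy as np


def lorenz_rhs(x, sigma=10.0, beta=2.66667, rho=28.0):
    """Hypothetical local copy of the Lorenz right-hand side.

    The default parameters are assumed to match pysindy.utils.lorenz.
    """
    return np.array([
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ])


x0 = np.array([-8.0, 8.0, 27.0])
exact_row0 = lorenz_rhs(x0)  # exact derivative at t = 0, as supplied in x_dot

# First DESIRED row copied from the error message (x_dot_processed[0][0])
processed_row0 = np.array([159.997055, -16.018512, -135.994382])

diff = processed_row0 - exact_row0
print("exact:", exact_row0)
print("processed:", processed_row0)
print("max abs difference:", np.abs(diff).max())
```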
### PySINDy/Python version information:
2.0.1.dev1+g9df5bccc9.d20251007