Skip to content

Commit 7a093b4

Browse files
committed
TST: update tests for DiscreteSINDy
1 parent bbc6c2c commit 7a093b4

File tree

3 files changed

+35
-63
lines changed

3 files changed

+35
-63
lines changed

test/test_pysindy.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sklearn.utils.validation import check_is_fitted
2424

2525
from pysindy import pysindy
26-
from pysindy import SINDy
26+
from pysindy import SINDy, DiscreteSINDy
2727
from pysindy.differentiation import SINDyDerivative
2828
from pysindy.differentiation import SmoothedFiniteDifference
2929
from pysindy.feature_library import FourierLibrary
@@ -331,18 +331,18 @@ def test_score_multiple_trajectories(data_multiple_trajectories):
331331
def test_fit_discrete_time(data_discrete_time):
332332
x = data_discrete_time
333333

334-
model = SINDy(discrete_time=True)
334+
model = DiscreteSINDy()
335335
model.fit(x, t=1)
336336
check_is_fitted(model)
337337

338-
model = SINDy(discrete_time=True)
339-
model.fit(x[:-1], x_dot=x[1:], t=1)
338+
model = DiscreteSINDy()
339+
model.fit(x[:-1], x_next=x[1:], t=1)
340340
check_is_fitted(model)
341341

342342

343343
def test_simulate_discrete_time(data_discrete_time):
344344
x = data_discrete_time
345-
model = SINDy(discrete_time=True)
345+
model = DiscreteSINDy()
346346
model.fit(x, t=1)
347347
n_steps = x.shape[0]
348348
x1 = model.simulate(x[0], n_steps)
@@ -359,17 +359,17 @@ def stop_func(xi):
359359

360360
def test_predict_discrete_time(data_discrete_time):
361361
x = data_discrete_time
362-
model = SINDy(discrete_time=True)
362+
model = DiscreteSINDy()
363363
model.fit(x, t=1)
364364
assert len(model.predict(x)) == len(x)
365365

366366

367367
def test_score_discrete_time(data_discrete_time):
368368
x = data_discrete_time
369-
model = SINDy(discrete_time=True)
369+
model = DiscreteSINDy()
370370
model.fit(x, t=1)
371371
assert model.score(x, t=1) > 0.75
372-
assert model.score(x, x_dot=x, t=1) < 1
372+
assert model.score(x, x_next=x, t=1) < 1
373373

374374

375375
def test_bad_multiple_trajectories(data_multiple_trajectories):
@@ -384,20 +384,20 @@ def test_fit_discrete_time_multiple_trajectories(
384384
data_discrete_time_multiple_trajectories,
385385
):
386386
x = data_discrete_time_multiple_trajectories
387-
model = SINDy(discrete_time=True)
387+
model = DiscreteSINDy()
388388
model.fit(x, t=1)
389389
check_is_fitted(model)
390390

391-
model = SINDy(discrete_time=True)
392-
model.fit(x, x_dot=x, t=1)
391+
model = DiscreteSINDy()
392+
model.fit(x, x_next=x, t=1)
393393
check_is_fitted(model)
394394

395395

396396
def test_predict_discrete_time_multiple_trajectories(
397397
data_discrete_time_multiple_trajectories,
398398
):
399399
x = data_discrete_time_multiple_trajectories
400-
model = SINDy(discrete_time=True)
400+
model = DiscreteSINDy()
401401
model.fit(x, t=1)
402402

403403
y = model.predict(x)
@@ -408,14 +408,14 @@ def test_score_discrete_time_multiple_trajectories(
408408
data_discrete_time_multiple_trajectories,
409409
):
410410
x = data_discrete_time_multiple_trajectories
411-
model = SINDy(discrete_time=True)
411+
model = DiscreteSINDy()
412412
model.fit(x, t=1)
413413

414414
s = model.score(x, t=1)
415415
assert s > 0.75
416416

417417
# x is not its own derivative, so we expect bad performance here
418-
s = model.score(x, x_dot=x, t=1)
418+
s = model.score(x, x_next=x, t=1)
419419
assert s < 1
420420

421421

@@ -445,7 +445,7 @@ def test_equations(data, capsys):
445445

446446
def test_print_discrete_time(data_discrete_time, capsys):
447447
x = data_discrete_time
448-
model = SINDy(discrete_time=True)
448+
model = DiscreteSINDy()
449449
model.fit(x, t=1)
450450
model.print()
451451

@@ -455,20 +455,6 @@ def test_print_discrete_time(data_discrete_time, capsys):
455455
assert "(x0)[k+1] = " in out
456456

457457

458-
def test_differentiate(data_lorenz, data_multiple_trajectories):
459-
x, t = data_lorenz
460-
461-
model = SINDy()
462-
model.differentiate(x, t)
463-
464-
x, t = data_multiple_trajectories
465-
model.differentiate(x, t)
466-
467-
model = SINDy(discrete_time=True)
468-
with pytest.raises(RuntimeError):
469-
model.differentiate(x, t)
470-
471-
472458
def test_coefficients_equals_complexity(data_lorenz):
473459
x, t = data_lorenz
474460
model = SINDy()
@@ -486,15 +472,15 @@ def test_simulate_errors(data_lorenz):
486472
with pytest.raises(ValueError):
487473
model.simulate(x[0], t=1)
488474

489-
model = SINDy(discrete_time=True)
475+
model = DiscreteSINDy()
490476
with pytest.raises(ValueError):
491477
model.simulate(x[0], t=[1, 2])
492478

493-
model = SINDy(discrete_time=True)
479+
model = DiscreteSINDy()
494480
with pytest.raises(ValueError):
495481
model.simulate(x[0], t=-1)
496482

497-
model = SINDy(discrete_time=True)
483+
model = DiscreteSINDy()
498484
with pytest.raises(ValueError):
499485
model.simulate(x[0], t=0.5)
500486

test/test_sindyc.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.linear_model import Lasso
99
from sklearn.utils.validation import check_is_fitted
1010

11-
from pysindy import SINDy
11+
from pysindy import SINDy, DiscreteSINDy
1212
from pysindy.optimizers import SR3
1313
from pysindy.optimizers import STLSQ
1414

@@ -257,12 +257,12 @@ def test_score_multiple_trajectories(data_multiple_trajectories):
257257
def test_fit_discrete_time(data):
258258
x, u = data
259259

260-
model = SINDy(discrete_time=True)
260+
model = DiscreteSINDy()
261261
model.fit(x, u=u, t=1)
262262
check_is_fitted(model)
263263

264-
model = SINDy(discrete_time=True)
265-
model.fit(x[:-1], u=u[:-1], x_dot=x[1:], t=1)
264+
model = DiscreteSINDy()
265+
model.fit(x[:-1], u=u[:-1], x_next=x[1:], t=1)
266266
check_is_fitted(model)
267267

268268

@@ -275,7 +275,7 @@ def test_fit_discrete_time(data):
275275
)
276276
def test_simulate_discrete_time(data):
277277
x, u = data
278-
model = SINDy(discrete_time=True)
278+
model = DiscreteSINDy()
279279
model.fit(x, u=u, t=1)
280280
n_steps = x.shape[0]
281281
x1 = model.simulate(x[0], t=n_steps, u=u)
@@ -294,7 +294,7 @@ def test_simulate_discrete_time(data):
294294
)
295295
def test_predict_discrete_time(data):
296296
x, u = data
297-
model = SINDy(discrete_time=True)
297+
model = DiscreteSINDy()
298298
print(x, u)
299299
model.fit(x, u=u, t=1)
300300
assert len(model.predict(x, u=u)) == len(x)
@@ -309,30 +309,30 @@ def test_predict_discrete_time(data):
309309
)
310310
def test_score_discrete_time(data):
311311
x, u = data
312-
model = SINDy(discrete_time=True)
312+
model = DiscreteSINDy()
313313
model.fit(x, u=u, t=1)
314314
assert model.score(x, u=u, t=1) > 0.75
315-
assert model.score(x, u=u, x_dot=x, t=1) < 1
315+
assert model.score(x, u=u, x_next=x, t=1) < 1
316316

317317

318318
def test_fit_discrete_time_multiple_trajectories(
319319
data_discrete_time_multiple_trajectories_c,
320320
):
321321
x, u = data_discrete_time_multiple_trajectories_c
322-
model = SINDy(discrete_time=True)
322+
model = DiscreteSINDy()
323323
model.fit(x, u=u, t=1)
324324
check_is_fitted(model)
325325

326-
model = SINDy(discrete_time=True)
327-
model.fit(x, u=u, x_dot=x, t=1)
326+
model = DiscreteSINDy()
327+
model.fit(x, u=u, x_next=x, t=1)
328328
check_is_fitted(model)
329329

330330

331331
def test_predict_discrete_time_multiple_trajectories(
332332
data_discrete_time_multiple_trajectories_c,
333333
):
334334
x, u = data_discrete_time_multiple_trajectories_c
335-
model = SINDy(discrete_time=True)
335+
model = DiscreteSINDy()
336336
model.fit(x, u=u, t=1)
337337

338338
y = model.predict(x, u=u)
@@ -343,14 +343,14 @@ def test_score_discrete_time_multiple_trajectories(
343343
data_discrete_time_multiple_trajectories_c,
344344
):
345345
x, u = data_discrete_time_multiple_trajectories_c
346-
model = SINDy(discrete_time=True)
346+
model = DiscreteSINDy()
347347
model.fit(x, u=u, t=1)
348348

349349
s = model.score(x, u=u, t=1)
350350
assert s > 0.75
351351

352352
# x is not its own derivative, so we expect bad performance here
353-
s = model.score(x, u=u, x_dot=x, t=1)
353+
s = model.score(x, u=u, x_next=x, t=1)
354354
assert s < 1
355355

356356

@@ -362,7 +362,7 @@ def test_simulate_errors(data_lorenz_c_1d):
362362
with pytest.raises(ValueError):
363363
model.simulate(x[0], t=1, u=u)
364364

365-
model = SINDy(discrete_time=True)
365+
model = DiscreteSINDy()
366366
with pytest.raises(ValueError):
367367
model.simulate(x[0], t=[1, 2], u=u)
368368

@@ -412,7 +412,7 @@ def test_extra_u_warn(data_lorenz_c_1d):
412412

413413
def test_extra_u_warn_discrete(data_discrete_time_c):
414414
x, u = data_discrete_time_c
415-
model = SINDy(discrete_time=True)
415+
model = DiscreteSINDy()
416416
model.fit(x, t=1)
417417

418418
with pytest.warns(UserWarning):
@@ -422,4 +422,4 @@ def test_extra_u_warn_discrete(data_discrete_time_c):
422422
model.score(x, u=u, t=1)
423423

424424
with pytest.warns(UserWarning):
425-
model.simulate(x[0], u=u, t=10, integrator_kws={"rtol": 0.1})
425+
model.simulate(x[0], u=u, t=10)

test/utils/test_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,6 @@ def test_reorder_constraints_2D():
5050
np.testing.assert_array_equal(result, target_order)
5151

5252

53-
def test_validate_controls():
54-
with pytest.raises(ValueError):
55-
validate_control_variables(1, [])
56-
with pytest.raises(ValueError):
57-
validate_control_variables([], 1)
58-
with pytest.raises(ValueError):
59-
validate_control_variables([], [1])
60-
arr = AxesArray(np.ones(4).reshape((2, 2)), axes={"ax_time": 0, "ax_coord": 1})
61-
with pytest.raises(ValueError):
62-
validate_control_variables([arr], [arr[:1]])
63-
u_mod = validate_control_variables([arr], [arr], trim_last_point=True)
64-
assert u_mod[0].n_time == 1
65-
66-
6753
@pytest.mark.parametrize(
6854
["regularization", "lam", "expected"],
6955
[

0 commit comments

Comments
 (0)