2323from sklearn .utils .validation import check_is_fitted
2424
2525from pysindy import pysindy
26- from pysindy import SINDy
26+ from pysindy import SINDy , DiscreteSINDy
2727from pysindy .differentiation import SINDyDerivative
2828from pysindy .differentiation import SmoothedFiniteDifference
2929from pysindy .feature_library import FourierLibrary
@@ -331,18 +331,18 @@ def test_score_multiple_trajectories(data_multiple_trajectories):
331331def test_fit_discrete_time (data_discrete_time ):
332332 x = data_discrete_time
333333
334- model = SINDy ( discrete_time = True )
334+ model = DiscreteSINDy ( )
335335 model .fit (x , t = 1 )
336336 check_is_fitted (model )
337337
338- model = SINDy ( discrete_time = True )
339- model .fit (x [:- 1 ], x_dot = x [1 :], t = 1 )
338+ model = DiscreteSINDy ( )
339+ model .fit (x [:- 1 ], x_next = x [1 :], t = 1 )
340340 check_is_fitted (model )
341341
342342
343343def test_simulate_discrete_time (data_discrete_time ):
344344 x = data_discrete_time
345- model = SINDy ( discrete_time = True )
345+ model = DiscreteSINDy ( )
346346 model .fit (x , t = 1 )
347347 n_steps = x .shape [0 ]
348348 x1 = model .simulate (x [0 ], n_steps )
@@ -359,17 +359,17 @@ def stop_func(xi):
359359
360360def test_predict_discrete_time (data_discrete_time ):
361361 x = data_discrete_time
362- model = SINDy ( discrete_time = True )
362+ model = DiscreteSINDy ( )
363363 model .fit (x , t = 1 )
364364 assert len (model .predict (x )) == len (x )
365365
366366
367367def test_score_discrete_time (data_discrete_time ):
368368 x = data_discrete_time
369- model = SINDy ( discrete_time = True )
369+ model = DiscreteSINDy ( )
370370 model .fit (x , t = 1 )
371371 assert model .score (x , t = 1 ) > 0.75
372- assert model .score (x , x_dot = x , t = 1 ) < 1
372+ assert model .score (x , x_next = x , t = 1 ) < 1
373373
374374
375375def test_bad_multiple_trajectories (data_multiple_trajectories ):
@@ -384,20 +384,20 @@ def test_fit_discrete_time_multiple_trajectories(
384384 data_discrete_time_multiple_trajectories ,
385385):
386386 x = data_discrete_time_multiple_trajectories
387- model = SINDy ( discrete_time = True )
387+ model = DiscreteSINDy ( )
388388 model .fit (x , t = 1 )
389389 check_is_fitted (model )
390390
391- model = SINDy ( discrete_time = True )
392- model .fit (x , x_dot = x , t = 1 )
391+ model = DiscreteSINDy ( )
392+ model .fit (x , x_next = x , t = 1 )
393393 check_is_fitted (model )
394394
395395
396396def test_predict_discrete_time_multiple_trajectories (
397397 data_discrete_time_multiple_trajectories ,
398398):
399399 x = data_discrete_time_multiple_trajectories
400- model = SINDy ( discrete_time = True )
400+ model = DiscreteSINDy ( )
401401 model .fit (x , t = 1 )
402402
403403 y = model .predict (x )
@@ -408,14 +408,14 @@ def test_score_discrete_time_multiple_trajectories(
408408 data_discrete_time_multiple_trajectories ,
409409):
410410 x = data_discrete_time_multiple_trajectories
411- model = SINDy ( discrete_time = True )
411+ model = DiscreteSINDy ( )
412412 model .fit (x , t = 1 )
413413
414414 s = model .score (x , t = 1 )
415415 assert s > 0.75
416416
417417 # x is not its own derivative, so we expect bad performance here
418- s = model .score (x , x_dot = x , t = 1 )
418+ s = model .score (x , x_next = x , t = 1 )
419419 assert s < 1
420420
421421
@@ -445,7 +445,7 @@ def test_equations(data, capsys):
445445
446446def test_print_discrete_time (data_discrete_time , capsys ):
447447 x = data_discrete_time
448- model = SINDy ( discrete_time = True )
448+ model = DiscreteSINDy ( )
449449 model .fit (x , t = 1 )
450450 model .print ()
451451
@@ -455,20 +455,6 @@ def test_print_discrete_time(data_discrete_time, capsys):
455455 assert "(x0)[k+1] = " in out
456456
457457
458- def test_differentiate (data_lorenz , data_multiple_trajectories ):
459- x , t = data_lorenz
460-
461- model = SINDy ()
462- model .differentiate (x , t )
463-
464- x , t = data_multiple_trajectories
465- model .differentiate (x , t )
466-
467- model = SINDy (discrete_time = True )
468- with pytest .raises (RuntimeError ):
469- model .differentiate (x , t )
470-
471-
472458def test_coefficients_equals_complexity (data_lorenz ):
473459 x , t = data_lorenz
474460 model = SINDy ()
@@ -486,15 +472,15 @@ def test_simulate_errors(data_lorenz):
486472 with pytest .raises (ValueError ):
487473 model .simulate (x [0 ], t = 1 )
488474
489- model = SINDy ( discrete_time = True )
475+ model = DiscreteSINDy ( )
490476 with pytest .raises (ValueError ):
491477 model .simulate (x [0 ], t = [1 , 2 ])
492478
493- model = SINDy ( discrete_time = True )
479+ model = DiscreteSINDy ( )
494480 with pytest .raises (ValueError ):
495481 model .simulate (x [0 ], t = - 1 )
496482
497- model = SINDy ( discrete_time = True )
483+ model = DiscreteSINDy ( )
498484 with pytest .raises (ValueError ):
499485 model .simulate (x [0 ], t = 0.5 )
500486
0 commit comments