11import numpy as np
2+ from pytest import raises , warns
23from scipy .integrate import solve_ivp
34from numpy .testing import assert_allclose , assert_equal
45
@@ -77,12 +78,30 @@ def lorenz_system(state_t, state):
7778 return sol .y
7879
7980
81+ def dummy_kernel (Xk , Yk ):
82+ """
83+ Externally-defined linear kernel function.
84+ """
85+ return Xk .T .dot (Yk )
86+
87+
88+ def dummy_grad (Xk , yk ):
89+ """
90+ Externally-defined linear kernel function gradient.
91+ """
92+ return Xk .T
93+
94+
8095# Generate Lorenz system data.
8196dt = 0.001
8297t = np .arange (0 , 10 , dt )
8398X = generate_lorenz_data (t )
8499Y = differentiate (X , dt )
85100
101+ # Generate short version of the data.
102+ X_short = X [:, :2000 ]
103+ Y_short = Y [:, :2000 ]
104+
86105# Set the LANDO learning parameters used by most test models.
87106lando_params = {}
88107lando_params ["svd_rank" ] = 3
@@ -100,7 +119,7 @@ def test_fitted():
100119 assert not lando .partially_fitted
101120 assert not lando .fitted
102121
103- lando .fit (X , Y )
122+ lando .fit (X_short , Y_short )
104123 assert lando .partially_fitted
105124 assert not lando .fitted
106125
@@ -109,15 +128,19 @@ def test_fitted():
109128 assert lando .fitted
110129
111130
112- def test_shapes ():
131+ def test_sparse_dictionary ():
113132 """
114- Test that the shapes of the sparse dictionary and the sparse dictionary
115- weights are as expected.
133+ Test that the shapes of the sparse dictionary and that the sparse
134+ dictionary weights are as expected.
116135 """
117136 lando = LANDO (** lando_params )
118- lando .fit (X , Y )
119- assert X .shape [0 ] == lando .sparse_dictionary .shape [0 ]
120- assert X .shape [- 1 ] > lando .sparse_dictionary .shape [- 1 ]
137+
138+ with raises (ValueError ):
139+ _ = lando .sparse_dictionary
140+
141+ lando .fit (X_short , Y_short )
142+ assert X_short .shape [0 ] == lando .sparse_dictionary .shape [0 ]
143+ assert X_short .shape [- 1 ] > lando .sparse_dictionary .shape [- 1 ]
121144 assert lando .operator .weights .shape == lando .sparse_dictionary .shape
122145
123146
@@ -206,7 +229,7 @@ def test_predict_2():
206229 dt = dt ,
207230 solve_ivp_opts = solve_ivp_opts ,
208231 )
209- assert relative_error (lando_predict , X_long ) < 0.05
232+ assert relative_error (lando_predict , X_long ) < 0.1
210233
211234
212235def test_predict_3 ():
@@ -240,7 +263,7 @@ def test_predict_4():
240263 tend = len (t ),
241264 continuous = False ,
242265 )
243- assert relative_error (lando_predict , X ) < 0.05
266+ assert relative_error (lando_predict , X ) < 0.1
244267
245268
246269def test_online_1 ():
@@ -254,7 +277,7 @@ def test_online_1():
254277 lando_online = LANDO (online = True , ** lando_params )
255278 lando_online .fit (X , Y )
256279
257- assert relative_error (lando_online .f (X ), lando .f (X )) < 1e-5
280+ assert relative_error (lando_online .f (X ), lando .f (X )) < 1e-3
258281
259282
260283def test_online_2 ():
@@ -270,7 +293,7 @@ def test_online_2():
270293 lando_online .fit (X [:, :batch_split ], Y [:, :batch_split ])
271294 lando_online .update (X [:, batch_split :], Y [:, batch_split :])
272295
273- assert relative_error (lando_online .f (X ), lando .f (X )) < 1e-5
296+ assert relative_error (lando_online .f (X ), lando .f (X )) < 1e-3
274297
275298
276299def test_online_3 ():
@@ -290,5 +313,243 @@ def test_online_3():
290313
291314 assert_equal (lando_online .fixed_point , lando .fixed_point )
292315 assert np .linalg .norm (lando_online .bias ) < 1e-3
293- assert relative_error (lando_online .linear , lando .linear ) < 1e-4
294- assert relative_error (lando_online .nonlinear (X ), lando .nonlinear (X )) < 1e-4
316+ assert relative_error (lando_online .linear , lando .linear ) < 1e-3
317+ assert relative_error (lando_online .nonlinear (X ), lando .nonlinear (X )) < 1e-3
318+
319+
320+ def test_default_kernel ():
321+ """
322+ Test that there are no errors when calling the default linear kernel.
323+ """
324+ lando = LANDO ()
325+ lando .fit (X_short , Y_short )
326+ lando .analyze_fixed_point (x_bar )
327+
328+
329+ def test_rbf_kernel ():
330+ """
331+ Test that there are no errors when calling the default RBF kernel.
332+ """
333+ lando = LANDO (kernel_metric = "rbf" )
334+ lando .fit (X_short , Y_short )
335+ lando .analyze_fixed_point (x_bar )
336+
337+
338+ def test_custom_kernel ():
339+ """
340+ Test that there are no errors when using custom kernel functions.
341+ """
342+ lando = LANDO (kernel_function = dummy_kernel , kernel_gradient = dummy_grad )
343+ lando .fit (X_short , Y_short )
344+ lando .analyze_fixed_point (x_bar )
345+
346+
347+ def test_custom_kernel_error ():
348+ """
349+ Test that an error occurs if a user attempts a fixed point analysis with a
350+ custom kernel function but without a kernel gradient function.
351+ """
352+ lando = LANDO (kernel_function = dummy_kernel )
353+ lando .fit (X_short , Y_short )
354+
355+ with raises (ValueError ):
356+ lando .analyze_fixed_point (x_bar )
357+
358+
359+ def test_kernel_inputs ():
360+ """
361+ Tests various errors caught by the test_kernel_inputs function.
362+ """
363+ # Error should be thrown if an invalid kernel metric is given.
364+ with raises (ValueError ):
365+ _ = LANDO (kernel_metric = "blah" )
366+
367+ # Error should be thrown if kernel_params isn't a dict.
368+ with raises (TypeError ):
369+ _ = LANDO (kernel_metric = "poly" , kernel_params = 3 )
370+
371+ # Error should be thrown if kernel_params contains invalid entries.
372+ with raises (ValueError ):
373+ _ = LANDO (kernel_metric = "poly" , kernel_params = {"blah" : 3 })
374+
375+
376+ def test_kernel_functions_1 ():
377+ """
378+ Tests various errors caught by the test_kernel_functions function.
379+ Tests for errors related to invalid inputs and combinations.
380+ """
381+ # Warning should arise if a kernel function is given without a gradient.
382+ with warns ():
383+ _ = LANDO (kernel_function = dummy_kernel )
384+
385+ # Error should be thrown if a gradient is given without a kernel function.
386+ with raises (ValueError ):
387+ _ = LANDO (kernel_gradient = dummy_grad )
388+
389+ # Error should be thrown if kernel_function isn't a function.
390+ with raises (TypeError ):
391+ _ = LANDO (kernel_function = 0 , kernel_gradient = dummy_grad )
392+
393+ # Error should be thrown if kernel_gradient isn't a function.
394+ with raises (TypeError ):
395+ _ = LANDO (kernel_function = dummy_kernel , kernel_gradient = 0 )
396+
397+
398+ def test_kernel_functions_2 ():
399+ """
400+ Tests various errors caught by the test_kernel_functions function.
401+ Tests for errors related to invalid function inputs.
402+ """
403+
404+ # Define functions that malfunction when called:
405+ def bad_kernel_1 (Xk , Yk ):
406+ return dummy_kernel (Xk .T , Yk )
407+
408+ def bad_grad_1 (Xk , yk ):
409+ return dummy_grad (Xk .T , yk )
410+
411+ # Define functions that yield incorrect dimensions:
412+ def bad_kernel_2 (Xk , Yk ):
413+ return dummy_kernel (Xk , Yk ).T
414+
415+ def bad_grad_2 (Xk , yk ):
416+ return dummy_grad (Xk , yk ).T
417+
418+ with raises (ValueError ):
419+ _ = LANDO (kernel_function = bad_kernel_1 , kernel_gradient = dummy_grad )
420+
421+ with raises (ValueError ):
422+ _ = LANDO (kernel_function = bad_kernel_2 , kernel_gradient = dummy_grad )
423+
424+ with raises (ValueError ):
425+ _ = LANDO (kernel_function = dummy_kernel , kernel_gradient = bad_grad_1 )
426+
427+ with raises (ValueError ):
428+ _ = LANDO (kernel_function = dummy_kernel , kernel_gradient = bad_grad_2 )
429+
430+
431+ def test_supported_kernels ():
432+ """
433+ Test a call to supported kernels.
434+ """
435+ lando = LANDO ()
436+ print (lando .supported_kernels )
437+
438+
439+ def test_errors_f ():
440+ """
441+ Test that expected errors are thrown when calling f.
442+ """
443+ lando = LANDO ()
444+
445+ # Error should be thrown if f is called prior to fitting.
446+ with raises (ValueError ):
447+ _ = lando .f (X_short )
448+
449+ lando .fit (X_short , Y_short )
450+
451+ # Error should be thrown if f is given data with the wrong dimension.
452+ with raises (ValueError ):
453+ _ = lando .f (X_short [:- 1 ])
454+
455+
456+ def test_errors_predict ():
457+ """
458+ Test that expected errors are thrown when calling predict.
459+ """
460+ lando = LANDO ()
461+
462+ # Error should be thrown if called prior to fitting.
463+ with raises (ValueError ):
464+ _ = lando .predict (x0 = (- 8 , 8 , 27 ), tend = len (t ))
465+
466+ lando .fit (X_short , Y_short )
467+
468+ # Error should be thrown if data is the wrong dimension.
469+ with raises (ValueError ):
470+ _ = lando .predict (x0 = (- 8 , 8 ), tend = len (t ))
471+
472+
473+ def test_errors_fixed_point ():
474+ """
475+ Test that expected errors are thrown when calling analyze_fixed_point.
476+ """
477+ lando = LANDO ()
478+
479+ # Error should be thrown if called prior to fitting.
480+ with raises (ValueError ):
481+ lando .analyze_fixed_point (x_bar )
482+
483+ lando .fit (X_short , Y_short )
484+
485+ # Error should be thrown if fixed point is the wrong dimension.
486+ with raises (ValueError ):
487+ lando .analyze_fixed_point (x_bar [:- 1 ])
488+
489+
490+ def test_errors_update ():
491+ """
492+ Test that expected errors are thrown when calling update.
493+ """
494+ lando = LANDO ()
495+
496+ # Error should be thrown if called prior to fitting.
497+ with raises (ValueError ):
498+ lando .update (X_short , Y_short )
499+
500+ lando .fit (X_short , Y_short )
501+
502+ # Error should be thrown if data is the wrong dimension.
503+ with raises (ValueError ):
504+ lando .update (X_short , Y_short [:, :- 1 ])
505+
506+
507+ def test_errors_get ():
508+ """
509+ Test that errors are thrown if the following are attempted to be retrieved
510+ prior to fully fitting: fixed_point, bias, linear, nonlinear.
511+ """
512+ lando = LANDO ()
513+
514+ # Errors should be thrown if called prior to fitting:
515+ with raises (ValueError ):
516+ _ = lando .fixed_point
517+
518+ with raises (ValueError ):
519+ _ = lando .bias
520+
521+ with raises (ValueError ):
522+ _ = lando .linear
523+
524+ with raises (ValueError ):
525+ _ = lando .nonlinear (X_short )
526+
527+ lando .fit (X_short , Y_short )
528+
529+ # Errors should still be thrown after a call to just fit:
530+ with raises (ValueError ):
531+ _ = lando .fixed_point
532+
533+ with raises (ValueError ):
534+ _ = lando .bias
535+
536+ with raises (ValueError ):
537+ _ = lando .linear
538+
539+ with raises (ValueError ):
540+ _ = lando .nonlinear (X_short )
541+
542+
543+ def test_reconstructed_data ():
544+ """
545+ Test reconstructed data shape and output.
546+ """
547+ lando = LANDO ()
548+ lando .fit (X_short , Y_short )
549+ lando .analyze_fixed_point (x_bar )
550+
551+ # Calling for reconstructed data should yield a warning.
552+ with warns ():
553+ X_recon = lando .reconstructed_data
554+
555+ assert X_recon .shape == X_short .shape
0 commit comments