Skip to content

Commit 5ab3b94

Browse files
committed
Added LANDO testing
1 parent 66563fb commit 5ab3b94

File tree

1 file changed

+274
-13
lines changed

1 file changed

+274
-13
lines changed

tests/test_lando.py

Lines changed: 274 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from pytest import raises, warns
23
from scipy.integrate import solve_ivp
34
from numpy.testing import assert_allclose, assert_equal
45

@@ -77,12 +78,30 @@ def lorenz_system(state_t, state):
7778
return sol.y
7879

7980

81+
def dummy_kernel(Xk, Yk):
82+
"""
83+
Externally-defined linear kernel function.
84+
"""
85+
return Xk.T.dot(Yk)
86+
87+
88+
def dummy_grad(Xk, yk):
89+
"""
90+
Externally-defined linear kernel function gradient.
91+
"""
92+
return Xk.T
93+
94+
8095
# Generate Lorenz system data.
8196
dt = 0.001
8297
t = np.arange(0, 10, dt)
8398
X = generate_lorenz_data(t)
8499
Y = differentiate(X, dt)
85100

101+
# Generate short version of the data.
102+
X_short = X[:, :2000]
103+
Y_short = Y[:, :2000]
104+
86105
# Set the LANDO learning parameters used by most test models.
87106
lando_params = {}
88107
lando_params["svd_rank"] = 3
@@ -100,7 +119,7 @@ def test_fitted():
100119
assert not lando.partially_fitted
101120
assert not lando.fitted
102121

103-
lando.fit(X, Y)
122+
lando.fit(X_short, Y_short)
104123
assert lando.partially_fitted
105124
assert not lando.fitted
106125

@@ -109,15 +128,19 @@ def test_fitted():
109128
assert lando.fitted
110129

111130

112-
def test_shapes():
131+
def test_sparse_dictionary():
113132
"""
114-
Test that the shapes of the sparse dictionary and the sparse dictionary
115-
weights are as expected.
133+
Test that the shapes of the sparse dictionary and that the sparse
134+
dictionary weights are as expected.
116135
"""
117136
lando = LANDO(**lando_params)
118-
lando.fit(X, Y)
119-
assert X.shape[0] == lando.sparse_dictionary.shape[0]
120-
assert X.shape[-1] > lando.sparse_dictionary.shape[-1]
137+
138+
with raises(ValueError):
139+
_ = lando.sparse_dictionary
140+
141+
lando.fit(X_short, Y_short)
142+
assert X_short.shape[0] == lando.sparse_dictionary.shape[0]
143+
assert X_short.shape[-1] > lando.sparse_dictionary.shape[-1]
121144
assert lando.operator.weights.shape == lando.sparse_dictionary.shape
122145

123146

@@ -206,7 +229,7 @@ def test_predict_2():
206229
dt=dt,
207230
solve_ivp_opts=solve_ivp_opts,
208231
)
209-
assert relative_error(lando_predict, X_long) < 0.05
232+
assert relative_error(lando_predict, X_long) < 0.1
210233

211234

212235
def test_predict_3():
@@ -240,7 +263,7 @@ def test_predict_4():
240263
tend=len(t),
241264
continuous=False,
242265
)
243-
assert relative_error(lando_predict, X) < 0.05
266+
assert relative_error(lando_predict, X) < 0.1
244267

245268

246269
def test_online_1():
@@ -254,7 +277,7 @@ def test_online_1():
254277
lando_online = LANDO(online=True, **lando_params)
255278
lando_online.fit(X, Y)
256279

257-
assert relative_error(lando_online.f(X), lando.f(X)) < 1e-5
280+
assert relative_error(lando_online.f(X), lando.f(X)) < 1e-3
258281

259282

260283
def test_online_2():
@@ -270,7 +293,7 @@ def test_online_2():
270293
lando_online.fit(X[:, :batch_split], Y[:, :batch_split])
271294
lando_online.update(X[:, batch_split:], Y[:, batch_split:])
272295

273-
assert relative_error(lando_online.f(X), lando.f(X)) < 1e-5
296+
assert relative_error(lando_online.f(X), lando.f(X)) < 1e-3
274297

275298

276299
def test_online_3():
@@ -290,5 +313,243 @@ def test_online_3():
290313

291314
assert_equal(lando_online.fixed_point, lando.fixed_point)
292315
assert np.linalg.norm(lando_online.bias) < 1e-3
293-
assert relative_error(lando_online.linear, lando.linear) < 1e-4
294-
assert relative_error(lando_online.nonlinear(X), lando.nonlinear(X)) < 1e-4
316+
assert relative_error(lando_online.linear, lando.linear) < 1e-3
317+
assert relative_error(lando_online.nonlinear(X), lando.nonlinear(X)) < 1e-3
318+
319+
320+
def test_default_kernel():
321+
"""
322+
Test that there are no errors when calling the default linear kernel.
323+
"""
324+
lando = LANDO()
325+
lando.fit(X_short, Y_short)
326+
lando.analyze_fixed_point(x_bar)
327+
328+
329+
def test_rbf_kernel():
330+
"""
331+
Test that there are no errors when calling the default RBF kernel.
332+
"""
333+
lando = LANDO(kernel_metric="rbf")
334+
lando.fit(X_short, Y_short)
335+
lando.analyze_fixed_point(x_bar)
336+
337+
338+
def test_custom_kernel():
339+
"""
340+
Test that there are no errors when using custom kernel functions.
341+
"""
342+
lando = LANDO(kernel_function=dummy_kernel, kernel_gradient=dummy_grad)
343+
lando.fit(X_short, Y_short)
344+
lando.analyze_fixed_point(x_bar)
345+
346+
347+
def test_custom_kernel_error():
348+
"""
349+
Test that an error occurs if a user attempts a fixed point analysis with a
350+
custom kernel function but without a kernel gradient function.
351+
"""
352+
lando = LANDO(kernel_function=dummy_kernel)
353+
lando.fit(X_short, Y_short)
354+
355+
with raises(ValueError):
356+
lando.analyze_fixed_point(x_bar)
357+
358+
359+
def test_kernel_inputs():
360+
"""
361+
Tests various errors caught by the test_kernel_inputs function.
362+
"""
363+
# Error should be thrown if an invalid kernel metric is given.
364+
with raises(ValueError):
365+
_ = LANDO(kernel_metric="blah")
366+
367+
# Error should be thrown if kernel_params isn't a dict.
368+
with raises(TypeError):
369+
_ = LANDO(kernel_metric="poly", kernel_params=3)
370+
371+
# Error should be thrown if kernel_params contains invalid entries.
372+
with raises(ValueError):
373+
_ = LANDO(kernel_metric="poly", kernel_params={"blah": 3})
374+
375+
376+
def test_kernel_functions_1():
377+
"""
378+
Tests various errors caught by the test_kernel_functions function.
379+
Tests for errors related to invalid inputs and combinations.
380+
"""
381+
# Warning should arise if a kernel function is given without a gradient.
382+
with warns():
383+
_ = LANDO(kernel_function=dummy_kernel)
384+
385+
# Error should be thrown if a gradient is given without a kernel function.
386+
with raises(ValueError):
387+
_ = LANDO(kernel_gradient=dummy_grad)
388+
389+
# Error should be thrown if kernel_function isn't a function.
390+
with raises(TypeError):
391+
_ = LANDO(kernel_function=0, kernel_gradient=dummy_grad)
392+
393+
# Error should be thrown if kernel_gradient isn't a function.
394+
with raises(TypeError):
395+
_ = LANDO(kernel_function=dummy_kernel, kernel_gradient=0)
396+
397+
398+
def test_kernel_functions_2():
399+
"""
400+
Tests various errors caught by the test_kernel_functions function.
401+
Tests for errors related to invalid function inputs.
402+
"""
403+
404+
# Define functions that malfunction when called:
405+
def bad_kernel_1(Xk, Yk):
406+
return dummy_kernel(Xk.T, Yk)
407+
408+
def bad_grad_1(Xk, yk):
409+
return dummy_grad(Xk.T, yk)
410+
411+
# Define functions that yield incorrect dimensions:
412+
def bad_kernel_2(Xk, Yk):
413+
return dummy_kernel(Xk, Yk).T
414+
415+
def bad_grad_2(Xk, yk):
416+
return dummy_grad(Xk, yk).T
417+
418+
with raises(ValueError):
419+
_ = LANDO(kernel_function=bad_kernel_1, kernel_gradient=dummy_grad)
420+
421+
with raises(ValueError):
422+
_ = LANDO(kernel_function=bad_kernel_2, kernel_gradient=dummy_grad)
423+
424+
with raises(ValueError):
425+
_ = LANDO(kernel_function=dummy_kernel, kernel_gradient=bad_grad_1)
426+
427+
with raises(ValueError):
428+
_ = LANDO(kernel_function=dummy_kernel, kernel_gradient=bad_grad_2)
429+
430+
431+
def test_supported_kernels():
432+
"""
433+
Test a call to supported kernels.
434+
"""
435+
lando = LANDO()
436+
print(lando.supported_kernels)
437+
438+
439+
def test_errors_f():
440+
"""
441+
Test that expected errors are thrown when calling f.
442+
"""
443+
lando = LANDO()
444+
445+
# Error should be thrown if f is called prior to fitting.
446+
with raises(ValueError):
447+
_ = lando.f(X_short)
448+
449+
lando.fit(X_short, Y_short)
450+
451+
# Error should be thrown if f is given data with the wrong dimension.
452+
with raises(ValueError):
453+
_ = lando.f(X_short[:-1])
454+
455+
456+
def test_errors_predict():
457+
"""
458+
Test that expected errors are thrown when calling predict.
459+
"""
460+
lando = LANDO()
461+
462+
# Error should be thrown if called prior to fitting.
463+
with raises(ValueError):
464+
_ = lando.predict(x0=(-8, 8, 27), tend=len(t))
465+
466+
lando.fit(X_short, Y_short)
467+
468+
# Error should be thrown if data is the wrong dimension.
469+
with raises(ValueError):
470+
_ = lando.predict(x0=(-8, 8), tend=len(t))
471+
472+
473+
def test_errors_fixed_point():
474+
"""
475+
Test that expected errors are thrown when calling analyze_fixed_point.
476+
"""
477+
lando = LANDO()
478+
479+
# Error should be thrown if called prior to fitting.
480+
with raises(ValueError):
481+
lando.analyze_fixed_point(x_bar)
482+
483+
lando.fit(X_short, Y_short)
484+
485+
# Error should be thrown if fixed point is the wrong dimension.
486+
with raises(ValueError):
487+
lando.analyze_fixed_point(x_bar[:-1])
488+
489+
490+
def test_errors_update():
491+
"""
492+
Test that expected errors are thrown when calling update.
493+
"""
494+
lando = LANDO()
495+
496+
# Error should be thrown if called prior to fitting.
497+
with raises(ValueError):
498+
lando.update(X_short, Y_short)
499+
500+
lando.fit(X_short, Y_short)
501+
502+
# Error should be thrown if data is the wrong dimension.
503+
with raises(ValueError):
504+
lando.update(X_short, Y_short[:, :-1])
505+
506+
507+
def test_errors_get():
508+
"""
509+
Test that errors are thrown if the following are attempted to be retrieved
510+
prior to fully fitting: fixed_point, bias, linear, nonlinear.
511+
"""
512+
lando = LANDO()
513+
514+
# Errors should be thrown if called prior to fitting:
515+
with raises(ValueError):
516+
_ = lando.fixed_point
517+
518+
with raises(ValueError):
519+
_ = lando.bias
520+
521+
with raises(ValueError):
522+
_ = lando.linear
523+
524+
with raises(ValueError):
525+
_ = lando.nonlinear(X_short)
526+
527+
lando.fit(X_short, Y_short)
528+
529+
# Errors should still be thrown after a call to just fit:
530+
with raises(ValueError):
531+
_ = lando.fixed_point
532+
533+
with raises(ValueError):
534+
_ = lando.bias
535+
536+
with raises(ValueError):
537+
_ = lando.linear
538+
539+
with raises(ValueError):
540+
_ = lando.nonlinear(X_short)
541+
542+
543+
def test_reconstructed_data():
544+
"""
545+
Test reconstructed data shape and output.
546+
"""
547+
lando = LANDO()
548+
lando.fit(X_short, Y_short)
549+
lando.analyze_fixed_point(x_bar)
550+
551+
# Calling for reconstructed data should yield a warning.
552+
with warns():
553+
X_recon = lando.reconstructed_data
554+
555+
assert X_recon.shape == X_short.shape

0 commit comments

Comments
 (0)