|
2 | 2 | import pytest |
3 | 3 | from numpy.testing import assert_array_equal |
4 | 4 |
|
| 5 | +from pysindy.utils import AxesArray |
5 | 6 | from pysindy.utils import get_prox |
6 | 7 | from pysindy.utils import get_regularization |
7 | 8 | from pysindy.utils import reorder_constraints |
| 9 | +from pysindy.utils import validate_control_variables |
8 | 10 |
|
9 | 11 |
|
10 | 12 | def test_reorder_constraints_1D(): |
@@ -48,6 +50,18 @@ def test_reorder_constraints_2D(): |
48 | 50 | np.testing.assert_array_equal(result, target_order) |
49 | 51 |
|
50 | 52 |
|
| 53 | +def test_validate_controls(): |
| 54 | + with pytest.raises(ValueError): |
| 55 | + validate_control_variables(1, []) |
| 56 | + with pytest.raises(ValueError): |
| 57 | + validate_control_variables([], 1) |
| 58 | + with pytest.raises(ValueError): |
| 59 | + validate_control_variables([], [1]) |
| 60 | + arr = AxesArray(np.ones(4).reshape((2, 2)), axes={"ax_time": 0, "ax_coord": 1}) |
| 61 | + with pytest.raises(ValueError): |
| 62 | + validate_control_variables([arr], [arr[:1]]) |
| 63 | + |
| 64 | + |
51 | 65 | @pytest.mark.parametrize( |
52 | 66 | ["regularization", "lam", "expected"], |
53 | 67 | [ |
|
0 commit comments