@@ -18,22 +18,31 @@ def test_constructor():
1818 pod = POD ()
1919 import torch
2020 rbf = RBF ()
21- rbf = ANN ([10 , 10 ], function = torch .nn .Softplus (), stop_training = [1000 ])
21+ # rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
2222 db = Database (param , snapshots .T )
2323 # rom = ROM(db, pod, rbf, plugins=[DatabaseScaler(StandardScaler(), 'full', 'snapshots')])
2424 rom = ROM (db , pod , rbf , plugins = [
25- DatabaseScaler (StandardScaler (), 'full ' , 'parameters' ),
25+ # DatabaseScaler(StandardScaler(), 'reduced ', 'parameters'),
2626 DatabaseScaler (StandardScaler (), 'reduced' , 'snapshots' )
2727 ])
2828 rom .fit ()
29- print ( rom . predict ( rom . database . parameters_matrix ))
30- print ( rom . database . snapshots_matrix )
29+
30+
3131
3232
33- # def test_values():
34- # snap = Snapshot(test_value)
35- # snap.values = test_value
36- # snap = Snapshot(test_value, space=test_space)
37- # with pytest.raises(ValueError):
38- # snap.values = test_value[:-2]
33+ def test_values ():
34+ pod = POD ()
35+ rbf = RBF ()
36+ db = Database (param , snapshots .T )
37+ rom = ROM (db , pod , rbf , plugins = [
38+ DatabaseScaler (StandardScaler (), 'reduced' , 'snapshots' )
39+ ])
40+ rom .fit ()
41+ test_param = param [2 ]
42+ truth_sol = db .snapshots_matrix [2 ]
43+ predicted_sol = rom .predict (test_param ).snapshots_matrix [0 ]
44+ print (predicted_sol )
45+ print (truth_sol )
46+ np .testing .assert_allclose (predicted_sol , truth_sol ,
47+ rtol = 1e-5 , atol = 1e-5 )
3948
0 commit comments