Skip to content

Commit b0a6935

Browse files
committed
Added test for the scaling and possible problem with param scaling
1 parent e04bbf4 commit b0a6935

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

ezyrb/plugin/scaler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def __init__(self, scaler, mode, target) -> None:
2626
self.scaler = scaler
2727
self.mode = mode
2828
self.target = target
29+
30+
if target == 'parameters': #TODO
31+
raise NotImplementedError("Scaling of parameters not implemented yet.")
2932

3033
@property
3134
def target(self):
@@ -141,5 +144,5 @@ def rom_postprocessing(self, rom):
141144
db.parameters_matrix,
142145
self.scaler.inverse_transform(self._select_matrix(db)),
143146
)
144-
147+
145148
rom._reduced_database = new_db

tests/test_scaler.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,31 @@ def test_constructor():
1818
pod = POD()
1919
import torch
2020
rbf = RBF()
21-
rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
21+
#rbf = ANN([10, 10], function=torch.nn.Softplus(), stop_training=[1000])
2222
db = Database(param, snapshots.T)
2323
# rom = ROM(db, pod, rbf, plugins=[DatabaseScaler(StandardScaler(), 'full', 'snapshots')])
2424
rom = ROM(db, pod, rbf, plugins=[
25-
DatabaseScaler(StandardScaler(), 'full', 'parameters'),
25+
#DatabaseScaler(StandardScaler(), 'reduced', 'parameters'),
2626
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots')
2727
])
2828
rom.fit()
29-
print(rom.predict(rom.database.parameters_matrix))
30-
print(rom.database.snapshots_matrix)
29+
30+
3131

3232

33-
# def test_values():
34-
# snap = Snapshot(test_value)
35-
# snap.values = test_value
36-
# snap = Snapshot(test_value, space=test_space)
37-
# with pytest.raises(ValueError):
38-
# snap.values = test_value[:-2]
33+
def test_values():
34+
pod = POD()
35+
rbf = RBF()
36+
db = Database(param, snapshots.T)
37+
rom = ROM(db, pod, rbf, plugins=[
38+
DatabaseScaler(StandardScaler(), 'reduced', 'snapshots')
39+
])
40+
rom.fit()
41+
test_param = param[2]
42+
truth_sol = db.snapshots_matrix[2]
43+
predicted_sol = rom.predict(test_param).snapshots_matrix[0]
44+
print(predicted_sol)
45+
print(truth_sol)
46+
np.testing.assert_allclose(predicted_sol, truth_sol,
47+
rtol=1e-5, atol=1e-5)
3948

0 commit comments

Comments
 (0)