Skip to content

Commit 6c8744b

Browse files
committed
Fix inverse and add to tests
1 parent 612123c commit 6c8744b

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

bayesflow/adapters/transforms/nnpe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
5757
return data + noise
5858

5959
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
60-
return np.exp(data)
60+
return data
6161

6262
def get_config(self) -> dict:
6363
return serialize({"slab_scale": self.slab_scale, "spike_scale": self.spike_scale, "seed": self.seed})

tests/test_adapters/test_adapters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def test_nnpe(random_data):
308308
result_training = ad(random_data, stage="training")
309309
result_validation = ad(random_data, stage="validation")
310310
result_inference = ad(random_data, stage="inference")
311+
result_inversed = ad(random_data, inverse=True)
311312
serialized = serialize(ad)
312313
deserialized = deserialize(serialized)
313314
reserialized = serialize(deserialized)
@@ -324,7 +325,8 @@ def test_nnpe(random_data):
324325
continue
325326
assert np.allclose(result_training[k], v)
326327

327-
# check that the validation and inference data is unchanged
328+
# check that the validation and inference data as well as inversed results are unchanged
328329
for k, v in random_data.items():
329330
assert np.allclose(result_validation[k], v)
330331
assert np.allclose(result_inference[k], v)
332+
assert np.allclose(result_inversed[k], v)

0 commit comments

Comments
 (0)