@@ -31,3 +31,42 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
3131 deserialized_processed = deserialized (random_data )
3232 for key , value in processed .items ():
3333 assert np .allclose (value , deserialized_processed [key ])
34+
35+
36+ def test_constrain ():
37+ import numpy as np
38+ import warnings
39+ from bayesflow .adapters import Adapter
40+
41+ data = {
42+ "x1" : np .random .exponential (1 , size = (32 , 1 )),
43+ "x2" : - np .random .exponential (1 , size = (32 , 1 )),
44+ "x3" : np .random .beta (0.5 , 0.5 , size = (32 , 1 )),
45+ "x4" : np .vstack ((np .zeros (shape = (16 , 1 )), np .ones (shape = (16 , 1 )))),
46+ "x5" : np .zeros (shape = (32 , 1 )),
47+ "x6" : np .zeros (shape = (32 , 1 )),
48+ }
49+
50+ adapter = (
51+ Adapter ()
52+ .constrain ("x1" , lower = 0 )
53+ .constrain ("x2" , upper = 0 )
54+ .constrain ("x3" , lower = 0 , upper = 1 )
55+ .constrain ("x4" , lower = 0 , upper = 1 , inclusive = "both" )
56+ .constrain ("x5" , lower = 0 , inclusive = "none" )
57+ .constrain ("x6" , upper = 0 , inclusive = "none" )
58+ )
59+
60+ with warnings .catch_warnings ():
61+ warnings .simplefilter ("ignore" , RuntimeWarning )
62+ result = adapter (data )
63+
64+ # checks if transformations indeed have been applied
65+ assert result ["x1" ].min () < 0.0
66+ assert result ["x2" ].max () > 0.0
67+ assert result ["x3" ].min () < 0.0
68+ assert result ["x3" ].max () > 1.0
69+ assert np .isfinite (result ["x4" ].min ())
70+ assert np .isfinite (result ["x4" ].max ())
71+ assert np .isneginf (result ["x5" ][0 ])
72+ assert np .isinf (result ["x6" ][0 ])
0 commit comments