@@ -137,13 +137,41 @@ def test_forward(self):
137137 self .assertEqual (x .shape , u .shape )
138138 self .assertEqual (log_jac .shape , (u .shape [0 ],))
139139
140+ def test_forward_out_of_bounds (self ):
141+ # Test forward transformation with out-of-bounds u values
142+ u = torch .tensor (
143+ [[1.5 , 0.5 ], [- 0.1 , 0.5 ]], dtype = torch .float64
144+ ) # Out-of-bounds values
145+ x , log_jac = self .vegas .forward (u )
146+
147+ # Check that out-of-bounds x values are clamped to grid boundaries
148+ self .assertTrue (torch .all (x >= 0.0 ))
149+ self .assertTrue (torch .all (x <= 1.0 ))
150+
151+ # Check log determinant adjustment for out-of-bounds cases
152+ self .assertEqual (log_jac .shape , (u .shape [0 ],))
153+
140154 def test_inverse (self ):
141155 # Test inverse transformation
142156 x = torch .tensor ([[0.1 , 0.2 ], [0.3 , 0.4 ]], dtype = torch .float64 )
143157 u , log_jac = self .vegas .inverse (x )
144158 self .assertEqual (u .shape , x .shape )
145159 self .assertEqual (log_jac .shape , (x .shape [0 ],))
146160
161+ def test_inverse_out_of_bounds (self ):
162+ # Test inverse transformation with out-of-bounds x values
163+ x = torch .tensor (
164+ [[1.5 , 0.5 ], [- 0.1 , 0.5 ]], dtype = torch .float64
165+ ) # Out-of-bounds values
166+ u , log_jac = self .vegas .inverse (x )
167+
168+ # Check that out-of-bounds u values are clamped to [0, 1]
169+ self .assertTrue (torch .all (u >= 0.0 ))
170+ self .assertTrue (torch .all (u <= 1.0 ))
171+
172+ # Check log determinant adjustment for out-of-bounds cases
173+ self .assertEqual (log_jac .shape , (x .shape [0 ],))
174+
147175 def test_train (self ):
148176 # Test training the Vegas class
149177 def f (x , fx ):
0 commit comments