 import unittest
+from unittest.mock import patch, MagicMock
+import os
 import torch
 import numpy as np
 from integrators import Integrator, MonteCarlo, MarkovChainMonteCarlo
+from integrators import get_ip, get_open_port, setup
+from integrators import gaussian, random_walk
 
-# from base import LinearMap, Uniform
 from maps import Configuration
+from base import EPSILON
+
+
+class TestIntegrators(unittest.TestCase):
+    @patch("socket.socket")
+    def test_get_ip(self, mock_socket):
+        # Mock the socket behavior
+        mock_socket_instance = mock_socket.return_value
+        mock_socket_instance.getsockname.return_value = ("192.168.1.1", 12345)
+        ip = get_ip()
+        self.assertEqual(ip, "192.168.1.1")
+
+    @patch("socket.socket")
+    def test_get_open_port(self, mock_socket):
+        # Mock the socket; get_open_port uses it as a context manager, hence __enter__
+        mock_socket_instance = mock_socket.return_value.__enter__.return_value
+        mock_socket_instance.getsockname.return_value = ("0.0.0.0", 54321)
+        port = get_open_port()
+        self.assertEqual(port, 54321)
+
+    @patch("torch.distributed.init_process_group")
+    @patch("torch.cuda.set_device")
+    @patch.dict(os.environ, {"LOCAL_RANK": "0"})
+    def test_setup(self, mock_set_device, mock_init_process_group):
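+        # Nested @patch decorators inject mocks bottom-up: set_device first, then init_process_group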
+        setup(backend="gloo")
+        mock_init_process_group.assert_called_once_with(backend="gloo")
+        mock_set_device.assert_called_once_with(0)
+
+    def test_random_walk(self):
+        # Test random_walk with an explicit step size
+        dim = 3
+        device = "cpu"
+        dtype = torch.float32
+        u = torch.rand(dim, device=device, dtype=dtype)
+        step_size = 0.2
+
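+        # Propose a new point from u; the checks below expect proposals to stay inside [0, 1]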
+        new_u = random_walk(dim, device, dtype, u, step_size=step_size)
+
+        # Validate output shape and range
+        self.assertEqual(new_u.shape, u.shape)
+        self.assertTrue(torch.all(new_u >= 0) and torch.all(new_u <= 1))
+
+        # Test with custom step size
+        custom_step_size = 0.5
+        new_u_custom = random_walk(dim, device, dtype, u, step_size=custom_step_size)
+        self.assertEqual(new_u_custom.shape, u.shape)
+        self.assertTrue(torch.all(new_u_custom >= 0) and torch.all(new_u_custom <= 1))
+
+    def test_gaussian(self):
+        # Test gaussian with explicit mean and std
+        dim = 3
+        device = "cpu"
+        dtype = torch.float32
+        u = torch.rand(dim, device=device, dtype=dtype)
+
+        mean = torch.zeros_like(u)
+        std = torch.ones_like(u)
+
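+        # Draw a Gaussian proposal around the given mean/std (standard normal here)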
+        new_u = gaussian(dim, device, dtype, u, mean=mean, std=std)
+
+        # Validate output shape
+        self.assertEqual(new_u.shape, u.shape)
+
+        # Test with custom mean and std
+        custom_mean = torch.full_like(u, 0.5)
+        custom_std = torch.full_like(u, 0.1)
+        new_u_custom = gaussian(dim, device, dtype, u, mean=custom_mean, std=custom_std)
+
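+        # 3-sigma bounds are probabilistic, so this can fail by chance (~1% for dim=3 if draws are N(mean, std))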
+        self.assertEqual(new_u_custom.shape, u.shape)
+        self.assertTrue(torch.all(new_u_custom > custom_mean - 3 * custom_std))
+        self.assertTrue(torch.all(new_u_custom < custom_mean + 3 * custom_std))
 
 
 class TestIntegrator(unittest.TestCase):
@@ -25,6 +99,46 @@ def test_initialization(self):
         self.assertTrue(hasattr(integrator.maps, "device"))
         self.assertTrue(hasattr(integrator.maps, "dtype"))
 
+    @patch("MCintegration.maps.CompositeMap")
+    @patch("MCintegration.base.LinearMap")
+    def test_initialization_with_maps(self, mock_linear_map, mock_composite_map):
+        # Mock the LinearMap and CompositeMap
+        mock_linear_map_instance = MagicMock()
+        mock_linear_map.return_value = mock_linear_map_instance
+        mock_composite_map_instance = MagicMock()
+        mock_composite_map.return_value = mock_composite_map_instance
+
+        # Create a mock map
+        mock_map = MagicMock()
+        mock_map.device = "cpu"
+        mock_map.dtype = torch.float32
+        mock_map.forward_with_detJ.return_value = (torch.rand(10, 2), torch.rand(10))
+
+        # Initialize Integrator with maps
+        integrator = Integrator(
+            bounds=self.bounds, f=self.f, maps=mock_map, batch_size=self.batch_size
+        )
+
+        # Assertions
+        self.assertEqual(integrator.dim, 2)
+        self.assertEqual(integrator.batch_size, 1000)
+        self.assertEqual(integrator.f_dim, 1)
+        self.assertTrue(hasattr(integrator.maps, "forward_with_detJ"))
+        self.assertTrue(hasattr(integrator.maps, "device"))
+        self.assertTrue(hasattr(integrator.maps, "dtype"))
+
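+        # Re-initialize with an explicit device and check that it is propagated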
+        integrator = Integrator(
+            bounds=self.bounds,
+            f=self.f,
+            maps=mock_map,
+            batch_size=self.batch_size,
+            device="cpu",
+        )
+
+        # The explicitly requested device should be honored
+        self.assertEqual(integrator.device, "cpu")
+        self.assertTrue(hasattr(integrator.maps, "device"))
+
     def test_bounds_conversion(self):
         # Test various input types
         test_cases = [
@@ -69,11 +183,11 @@ def test_invalid_bounds(self):
             with self.assertRaises(error_type):
                 Integrator(bounds=bounds, f=self.f)
 
-    def test_device_handling(self):
-        if torch.cuda.is_available():
-            integrator = Integrator(bounds=self.bounds, f=self.f, device="cuda")
-            self.assertTrue(integrator.bounds.is_cuda)
-            self.assertTrue(integrator.maps.device == "cuda")
+    # def test_device_handling(self):
+    #     if torch.cuda.is_available():
+    #         integrator = Integrator(bounds=self.bounds, f=self.f, device="cuda")
+    #         self.assertTrue(integrator.bounds.is_cuda)
+    #         self.assertTrue(integrator.maps.device == "cuda")
 
     def test_dtype_handling(self):
         dtypes = [torch.float32, torch.float64]
@@ -167,6 +281,27 @@ def test_batch_size_handling(self):
             # Should not raise warning
             self.mc(neval=neval, nblock=nblock)
 
+    def test_block_size_warning(self):
+        mc = MonteCarlo(bounds=self.bounds, f=self.simple_integral, batch_size=1000)
+        with self.assertWarns(UserWarning):
+            mc(neval=500, nblock=10)  # neval too small for nblock
+
+    def test_varying_nblock(self):
+        test_cases = [
+            (10000, 10),  # Standard case
+            (10000, 1),  # Single block
+            (10000, 100),  # Many blocks
+        ]
+
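+        # The estimate should stay ~1.0 regardless of how neval is split into blocks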
+        for neval, nblock in test_cases:
+            with self.subTest(neval=neval, nblock=nblock):
+                result = self.mc(neval=neval, nblock=nblock)
+                if hasattr(result, "mean"):
+                    value = result.mean
+                else:
+                    value = result
+                self.assertAlmostEqual(float(value), 1.0, delta=0.1)
+
 
 class TestMarkovChainMonteCarlo(unittest.TestCase):
     def setUp(self):
@@ -237,29 +372,40 @@ def test_burnin_effect(self):
                     value = result
                 self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
 
-    # def test_mix_rate_sensitivity(self):
-    #     # Modified mix rate test to be more robust
-    #     mix_rates = [0.0, 0.5, 1.0]
-    #     results = []
+    def test_sample_acceptance(self):
+        config = Configuration(
+            self.mcmc.batch_size,
+            self.mcmc.dim,
+            self.mcmc.f_dim,
+            self.mcmc.device,
+            self.mcmc.dtype,
+        )
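+        # Build the configuration by hand: sample u from q0, map to x, accumulate the det-Jacobians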
+        config.u, config.detJ = self.mcmc.q0.sample_with_detJ(self.mcmc.batch_size)
+        config.x, detj = self.mcmc.maps.forward_with_detJ(config.u)
+        config.detJ *= detj
+        config.weight = torch.rand(self.mcmc.batch_size, device=self.mcmc.device)
 
-    #     for mix_rate in mix_rates:
-    #         accumulated_error = 0
-    #         n_trials = 3  # Run multiple trials for each mix_rate
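+        # Advance the chain a single step (exact mix_rate proposal semantics assumed from the API)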
+        self.mcmc.sample(config, nsteps=1, mix_rate=0.5)
 
-    #         for _ in range(n_trials):
-    #             result = self.mcmc(neval=50000, mix_rate=mix_rate, nblock=10)
-    #             if hasattr(result, "mean"):
-    #                 value = result.mean
-    #                 error = result.sdev
-    #             else:
-    #                 value = result
-    #                 error = abs(float(value) - 1.0)
-    #             accumulated_error += error
+        # Validate acceptance: weights stay above EPSILON and u/x shapes stay aligned
+        self.assertTrue(torch.all(config.weight >= EPSILON))
+        self.assertEqual(config.u.shape, config.x.shape)
 
-    #         results.append(accumulated_error / n_trials)
+    def test_varying_mix_rate(self):
+        test_cases = [
+            (0.1, 0.2),  # Low mix rate
+            (0.5, 0.1),  # Medium mix rate
+            (0.9, 0.05),  # High mix rate
+        ]
 
-    #     # We expect moderate mix rates to have lower average error
-    #     self.assertLess(results[1], max(results[0], results[2]))
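+        # Tolerances loosen at low mix rate, where the chain is assumed to mix more slowly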
+        for mix_rate, tolerance in test_cases:
+            with self.subTest(mix_rate=mix_rate):
+                result = self.mcmc(neval=50000, mix_rate=mix_rate, nblock=10)
+                if hasattr(result, "mean"):
+                    value = result.mean
+                else:
+                    value = result
+                self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
 
 
 class TestDistributedFunctionality(unittest.TestCase):
@@ -271,26 +417,26 @@ def test_distributed_initialization(self):
         self.assertEqual(integrator.rank, 0)
         self.assertEqual(integrator.world_size, 1)
 
-    @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
-    def test_multi_gpu_consistency(self):
-        if torch.cuda.device_count() >= 2:
-            bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
-            f = lambda x, fx: torch.ones_like(x)
+    # @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
+    # def test_multi_gpu_consistency(self):
+    #     if torch.cuda.device_count() >= 2:
+    #         bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
+    #         f = lambda x, fx: torch.ones_like(x)
 
-            # Create two integrators on different devices
-            integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
-            integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
+    #         # Create two integrators on different devices
+    #         integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
+    #         integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
 
-            # Results should be consistent across devices
-            result1 = integrator1(neval=10000)
-            result2 = integrator2(neval=10000)
+    #         # Results should be consistent across devices
+    #         result1 = integrator1(neval=10000)
+    #         result2 = integrator2(neval=10000)
 
-            if hasattr(result1, "mean"):
-                value1, value2 = result1.mean, result2.mean
-            else:
-                value1, value2 = result1, result2
+    #         if hasattr(result1, "mean"):
+    #             value1, value2 = result1.mean, result2.mean
+    #         else:
+    #             value1, value2 = result1, result2
 
-            self.assertAlmostEqual(float(value1), float(value2), places=1)
+    #         self.assertAlmostEqual(float(value1), float(value2), places=1)
 
 
 if __name__ == "__main__":