55import sys
66import traceback
77from MCintegration import MonteCarlo , MarkovChainMonteCarlo , Vegas
8+
89os .environ ["NCCL_DEBUG" ] = "OFF"
910os .environ ["TORCH_DISTRIBUTED_DEBUG" ] = "OFF"
1011os .environ ["GLOG_minloglevel" ] = "2"
1112os .environ ["MASTER_ADDR" ] = os .getenv ("MASTER_ADDR" , "localhost" )
1213os .environ ["MASTER_PORT" ] = os .getenv ("MASTER_PORT" , "12355" )
1314
1415backend = "nccl"
16+ # backend = "gloo"
17+
1518
1619def init_process (rank , world_size , fn , backend = backend ):
1720 try :
@@ -23,6 +26,7 @@ def init_process(rank, world_size, fn, backend=backend):
2326 dist .destroy_process_group ()
2427 raise e
2528
29+
2630def run_mcmc (rank , world_size ):
2731 try :
2832 if rank != 0 :
@@ -42,7 +46,12 @@ def func(x, f):
4246 ninc = 1000
4347 n_therm = 20
4448
45- device = torch .device (f"cuda:{ rank } " )
49+ if backend == "gloo" :
50+ device = torch .device ("cpu" )
51+ elif backend == "nccl" :
52+ device = torch .device (f"cuda:{ rank } " )
53+ else :
54+ raise ValueError (f"Invalid backend: { backend } " )
4655
4756 print (f"Process { rank } using device: { device } " )
4857
@@ -54,26 +63,32 @@ def func(x, f):
5463
5564 print ("Integration Results for log(x)/sqrt(x):" )
5665
57-
5866 # Plain MC Integration
59- mc_integrator = MonteCarlo (bounds , func , batch_size = batch_size ,device = device )
67+ mc_integrator = MonteCarlo (bounds , func , batch_size = batch_size , device = device )
6068 print ("Plain MC Integral Result:" , mc_integrator (n_eval ))
6169
6270 # MCMC Integration
6371 mcmc_integrator = MarkovChainMonteCarlo (
64- bounds , func , batch_size = batch_size , nburnin = n_therm ,device = device
72+ bounds , func , batch_size = batch_size , nburnin = n_therm , device = device
6573 )
6674 print ("MCMC Integral Result:" , mcmc_integrator (n_eval , mix_rate = 0.5 ))
6775
6876 # Perform VEGAS integration
69- vegas_integrator = MonteCarlo (bounds , func , maps = vegas_map , batch_size = batch_size ,device = device )
77+ vegas_integrator = MonteCarlo (
78+ bounds , func , maps = vegas_map , batch_size = batch_size , device = device
79+ )
7080 res = vegas_integrator (n_eval )
7181
7282 print ("VEGAS Integral Result:" , res )
7383
7484 # VEGAS-MCMC Integration
7585 vegasmcmc_integrator = MarkovChainMonteCarlo (
76- bounds , func , maps = vegas_map , batch_size = batch_size , nburnin = n_therm ,device = device
86+ bounds ,
87+ func ,
88+ maps = vegas_map ,
89+ batch_size = batch_size ,
90+ nburnin = n_therm ,
91+ device = device ,
7792 )
7893 res_vegasmcmc = vegasmcmc_integrator (n_eval , mix_rate = 0.5 )
7994 print ("VEGAS-MCMC Integral Result:" , res_vegasmcmc )
@@ -100,6 +115,7 @@ def test_mcmc(world_size):
100115 except Exception as e :
101116 print (f"Error in test_mcmc: { e } " )
102117
118+
103119if __name__ == "__main__" :
104120 mp .set_start_method ("spawn" , force = True )
105- test_mcmc (4 )
121+ test_mcmc (4 )
0 commit comments