File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 1515print ("JAX devices:" , jax .devices ())
1616
1717# Set double precision
18- torch .set_default_dtype (torch .float64 )
19- jax .config .update ("jax_enable_x64" , True )
18+ torch .set_default_dtype (torch .float32 )
19+ jax .config .update ("jax_enable_x64" , False )
2020
2121
2222def run_cosmo_tt_example (
@@ -153,7 +153,7 @@ def ln_posterior_torch_maybe(theta_torch):
153153 # ===========================================================================
154154
155155 # Convert limits to approximation domain
156- approximation_domain = torch .tensor (limits , dtype = torch .float64 )
156+ approximation_domain = torch .tensor (limits , dtype = torch .float32 )
157157
158158 labels = input_params
159159
@@ -181,7 +181,7 @@ def ln_posterior_torch_maybe(theta_torch):
181181 if False :
182182 hm .logs .info_log ("Computing evidence from TT cores..." )
183183 reduced_cores = []
184- evidence = torch .eye (1 , dtype = torch .float64 , device = device ) # Initialize evidence as 1
184+ evidence = torch .eye (1 , dtype = torch .float32 , device = device ) # Initialize evidence as 1
185185 for k in range (ndim ):
186186 core_k = dirt .sirts [0 ].ftt .tt .cores [k ]
187187 reduced_core_k = core_k .sum (dim = 1 )** 2 # Second dimension sum
You can’t perform that action at this time.
0 commit comments