Skip to content

Commit 51da51b

Browse files
Switch to single precision
1 parent 7bbe235 commit 51da51b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/cosmo_tt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
print("JAX devices:", jax.devices())
1616

1717
# Set double precision
18-
torch.set_default_dtype(torch.float64)
19-
jax.config.update("jax_enable_x64", True)
18+
torch.set_default_dtype(torch.float32)
19+
jax.config.update("jax_enable_x64", False)
2020

2121

2222
def run_cosmo_tt_example(
@@ -153,7 +153,7 @@ def ln_posterior_torch_maybe(theta_torch):
153153
# ===========================================================================
154154

155155
# Convert limits to approximation domain
156-
approximation_domain = torch.tensor(limits, dtype=torch.float64)
156+
approximation_domain = torch.tensor(limits, dtype=torch.float32)
157157

158158
labels = input_params
159159

@@ -181,7 +181,7 @@ def ln_posterior_torch_maybe(theta_torch):
181181
if False:
182182
hm.logs.info_log("Computing evidence from TT cores...")
183183
reduced_cores = []
184-
evidence = torch.eye(1, dtype=torch.float64, device=device) # Initialize evidence as 1
184+
evidence = torch.eye(1, dtype=torch.float32, device=device) # Initialize evidence as 1
185185
for k in range(ndim):
186186
core_k = dirt.sirts[0].ftt.tt.cores[k]
187187
reduced_core_k = core_k.sum(dim=1)**2 # Second dimension sum

0 commit comments

Comments
 (0)