Skip to content

Commit b5e3267

Browse files
authored
Adds casting to the default dtype in turbo (#60)
* Adds casting to the default dtype in turbo * Moves the default dtype to torch's * Forgot to move one tensor
1 parent f07c60a commit b5e3267

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from poli_baselines.core.step_by_step_solver import StepByStepSolver
2222

2323
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24-
DEFAULT_DTYPE = torch.double
24+
DEFAULT_DTYPE = torch.get_default_dtype()
2525

2626

2727
NUM_RESTARTS = 10
@@ -72,8 +72,12 @@ def from_turbo(X):
7272

7373
self.device = device
7474
self.to_turbo, self.from_turbo = make_transforms()
75-
self.X_turbo = torch.tensor(self.to_turbo(x0)).to(self.device)
76-
self.Y_turbo = torch.tensor(y0).to(self.device)
75+
self.X_turbo = (
76+
torch.tensor(self.to_turbo(x0))
77+
.to(self.device)
78+
.to(torch.get_default_dtype())
79+
)
80+
self.Y_turbo = torch.tensor(y0).to(self.device).to(torch.get_default_dtype())
7781
self.batch_size = 1
7882
dim = x0.shape[1]
7983
self.state = TurboState(dim, batch_size=self.batch_size)
@@ -218,7 +222,7 @@ def generate_batch(
218222
mask[ind, torch.randint(0, dim - 1, size=(len(ind),), device=device)] = 1
219223

220224
# Create candidate points from the perturbations and the mask
221-
X_cand = x_center.expand(n_candidates, dim).clone().to(device)
225+
X_cand = x_center.expand(n_candidates, dim).clone().to(device).to(DEFAULT_DTYPE)
222226
X_cand[mask] = pert[mask]
223227

224228
# Sample on the candidate points

0 commit comments

Comments
 (0)