Skip to content

Commit 9fc28ea

Browse files
authored
Merge pull request #79 from thomaswmorris/logging
Add logging
2 parents 1e747ea + 9771993 commit 9fc28ea

File tree

5 files changed

+48
-18
lines changed

5 files changed

+48
-18
lines changed

src/blop/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
import logging
2+
13
from . import utils # noqa F401
24
from ._version import __version__, __version_tuple__ # noqa: F401
35
from .agent import Agent # noqa F401
46
from .dofs import DOF # noqa F401
57
from .objectives import Objective # noqa F401
8+
9+
logging.basicConfig(
10+
level=logging.INFO,
11+
format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
12+
datefmt="%Y-%m-%d %H:%M:%S",
13+
)
14+
15+
logger = logging.getLogger("maria")

src/blop/agent.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from .objectives import Objective, ObjectiveList
3636
from .plans import default_acquisition_plan
3737

38+
logger = logging.getLogger("maria")
39+
3840
warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)
3941

4042
mpl.rc("image", cmap="coolwarm")
@@ -382,7 +384,7 @@ def tell(
382384
t0 = ttime.monotonic()
383385
train_model(obj.model)
384386
if self.verbose:
385-
print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
387+
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
386388

387389
else:
388390
train_model(obj.model, hypers=cached_hypers)
@@ -432,7 +434,7 @@ def learn(
432434

433435
for i in range(iterations):
434436
if self.verbose:
435-
print(f"running iteration {i + 1} / {iterations}")
437+
logger.info(f"running iteration {i + 1} / {iterations}")
436438
for single_acqf in np.atleast_1d(acqf):
437439
res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs)
438440
new_table = yield from self.acquire(res["points"])
@@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs):
761763
train_model(obj.validity_conjugate_model)
762764

763765
if self.verbose:
764-
print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
766+
logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
765767

766768
self.n_last_trained = len(self._table)
767769

@@ -861,15 +863,11 @@ def _set_hypers(self, hypers):
861863
self.validity_constraint.load_state_dict(hypers["validity_constraint"])
862864

863865
def constraint(self, x):
864-
p = torch.ones(x.shape[:-1])
866+
log_p = torch.zeros(x.shape[:-1])
865867
for obj in self.objectives(active=True):
866-
# if the constraint is non-trivial
867-
if obj.constraint is not None:
868-
p *= obj.constraint_probability(x)
869-
# if the validity constaint is non-trivial
870-
if obj.validity_conjugate_model is not None:
871-
p *= obj.validity_constraint(x)
872-
return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)
868+
log_p += obj.log_total_constraint(x)
869+
870+
return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)
873871

874872
@property
875873
def hypers(self) -> dict:

src/blop/bayesian/models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,22 @@
66
from . import kernels
77

88

9-
def train_model(model, hypers=None, **kwargs):
9+
def train_model(model, hypers=None, max_fails=4, **kwargs):
1010
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
11-
if hypers is not None:
12-
model.load_state_dict(hypers)
13-
else:
14-
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
15-
model.trained = True
11+
fails = 0
12+
while True:
13+
try:
14+
if hypers is not None:
15+
model.load_state_dict(hypers)
16+
else:
17+
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
18+
model.trained = True
19+
return
20+
except Exception as e:
21+
if fails < max_fails:
22+
fails += 1
23+
else:
24+
raise e
1625

1726

1827
def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0):

src/blop/objectives.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def constrain(self, y):
155155
else:
156156
return np.array([value in self.constraint for value in np.atleast_1d(y)])
157157

158+
def log_total_constraint(self, x):
159+
160+
log_p = 0
161+
# if you have a constraint
162+
if self.constraint is not None:
163+
log_p += self.constraint_probability(x).log()
164+
165+
# if the validity constaint is non-trivial
166+
if self.validity_conjugate_model is not None:
167+
log_p += self.validity_constraint(x).log()
168+
169+
return log_p
170+
158171
@property
159172
def _trust_domain(self):
160173
if self.trust_domain is None:

src/blop/plotting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):
3333
test_model_inputs = agent.dofs(active=True).transform(test_inputs)
3434

3535
for obj_index, obj in enumerate(fitness_objs):
36-
obj_values = agent.train_targets()(obj.name).numpy()
36+
obj_values = agent.train_targets()[obj.name].numpy()
3737

3838
color = DEFAULT_COLOR_LIST[obj_index]
3939

0 commit comments

Comments
 (0)