|
35 | 35 | from .objectives import Objective, ObjectiveList |
36 | 36 | from .plans import default_acquisition_plan |
37 | 37 |
|
| 38 | +logger = logging.getLogger("maria") |
| 39 | + |
38 | 40 | warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning) |
39 | 41 |
|
40 | 42 | mpl.rc("image", cmap="coolwarm") |
@@ -382,7 +384,7 @@ def tell( |
382 | 384 | t0 = ttime.monotonic() |
383 | 385 | train_model(obj.model) |
384 | 386 | if self.verbose: |
385 | | - print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms") |
| 387 | + logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms") |
386 | 388 |
|
387 | 389 | else: |
388 | 390 | train_model(obj.model, hypers=cached_hypers) |
@@ -432,7 +434,7 @@ def learn( |
432 | 434 |
|
433 | 435 | for i in range(iterations): |
434 | 436 | if self.verbose: |
435 | | - print(f"running iteration {i + 1} / {iterations}") |
| 437 | + logger.info(f"running iteration {i + 1} / {iterations}") |
436 | 438 | for single_acqf in np.atleast_1d(acqf): |
437 | 439 | res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs) |
438 | 440 | new_table = yield from self.acquire(res["points"]) |
@@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs): |
761 | 763 | train_model(obj.validity_conjugate_model) |
762 | 764 |
|
763 | 765 | if self.verbose: |
764 | | - print(f"trained models in {ttime.monotonic() - t0:.01f} seconds") |
| 766 | + logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds") |
765 | 767 |
|
766 | 768 | self.n_last_trained = len(self._table) |
767 | 769 |
|
@@ -861,15 +863,11 @@ def _set_hypers(self, hypers): |
861 | 863 | self.validity_constraint.load_state_dict(hypers["validity_constraint"]) |
862 | 864 |
|
863 | 865 | def constraint(self, x): |
864 | | - p = torch.ones(x.shape[:-1]) |
| 866 | + log_p = torch.zeros(x.shape[:-1]) |
865 | 867 | for obj in self.objectives(active=True): |
866 | | - # if the constraint is non-trivial |
867 | | - if obj.constraint is not None: |
868 | | - p *= obj.constraint_probability(x) |
869 | | - # if the validity constaint is non-trivial |
870 | | - if obj.validity_conjugate_model is not None: |
871 | | - p *= obj.validity_constraint(x) |
872 | | - return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) |
| 868 | + log_p += obj.log_total_constraint(x) |
| 869 | + |
| 870 | + return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) |
873 | 871 |
|
874 | 872 | @property |
875 | 873 | def hypers(self) -> dict: |
|
0 commit comments