Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion bofire/data_models/acquisition_functions/acquisition_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Dict, Literal, Optional
from typing import Annotated, Dict, List, Literal, Optional, Union

from pydantic import Field, PositiveFloat

Expand All @@ -18,6 +18,10 @@ class MultiObjectiveAcquisitionFunction(AcquisitionFunction):
type: str


class MultiFideltyAcquisitionFunction(AcquisitionFunction):
type: str


class qNEI(SingleObjectiveAcquisitionFunction):
type: Literal["qNEI"] = "qNEI"
prune_baseline: bool = True
Expand Down Expand Up @@ -87,3 +91,17 @@ class qNegIntPosVar(SingleObjectiveAcquisitionFunction):
type: Literal["qNegIntPosVar"] = "qNegIntPosVar"
n_mc_samples: IntPowerOfTwo = 512
weights: Optional[Dict[str, PositiveFloat]] = Field(default_factory=lambda: None)


class qMFMES(MultiFideltyAcquisitionFunction):
type: Literal["qMFMES"] = "qMFMES"
num_fantasies: IntPowerOfTwo = 16
num_mv_samples: int = 10
num_y_samples: IntPowerOfTwo = 128
fidelity_costs: list[float]


class qMFVariance(MultiFideltyAcquisitionFunction):
type: Literal["qMFVariance"] = "qMFVariance"
beta: Annotated[float, Field(ge=0)] = 0.2
fidelity_thresholds: Union[List[float], float] = 0.1
6 changes: 6 additions & 0 deletions bofire/data_models/acquisition_functions/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from bofire.data_models.acquisition_functions.acquisition_function import (
AcquisitionFunction,
MultiFideltyAcquisitionFunction,
MultiObjectiveAcquisitionFunction,
SingleObjectiveAcquisitionFunction,
qEHVI,
Expand All @@ -10,6 +11,8 @@
qLogEI,
qLogNEHVI,
qLogNEI,
qMFMES,
qMFVariance,
qNegIntPosVar,
qNEHVI,
qNEI,
Expand All @@ -23,6 +26,7 @@
AcquisitionFunction,
SingleObjectiveAcquisitionFunction,
MultiObjectiveAcquisitionFunction,
MultiFideltyAcquisitionFunction,
]

AnyAcquisitionFunction = Union[
Expand Down Expand Up @@ -53,3 +57,5 @@
AnyMultiObjectiveAcquisitionFunction = Union[qEHVI, qLogEHVI, qNEHVI, qLogNEHVI]

AnyActiveLearningAcquisitionFunction = qNegIntPosVar

AnyMultiFidelityAcquisitionFunction = Union[qMFMES, qMFVariance]
20 changes: 14 additions & 6 deletions bofire/data_models/strategies/predictives/multi_fidelity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import List, Literal, Union
from typing import Literal

from pydantic import model_validator
from pydantic import Field, model_validator

from bofire.data_models.acquisition_functions.api import (
AnyMultiFidelityAcquisitionFunction,
qMFVariance,
)
from bofire.data_models.domain.api import Domain, Outputs
from bofire.data_models.features.api import TaskInput
from bofire.data_models.strategies.predictives.sobo import SoboStrategy
Expand All @@ -11,20 +15,24 @@
class MultiFidelityStrategy(SoboStrategy):
type: Literal["MultiFidelityStrategy"] = "MultiFidelityStrategy"

fidelity_thresholds: Union[List[float], float] = 0.1
fidelity_acquisition_function: AnyMultiFidelityAcquisitionFunction = Field(
default_factory=lambda: qMFVariance(),
)

@model_validator(mode="after")
def validate_tasks_and_fidelity_thresholds(self):
"""Ensures that there is one threshold per fidelity"""
task_input, *_ = self.domain.inputs.get(includes=TaskInput, exact=True)
num_tasks = len(task_input.categories) # type: ignore
fid_acqf = self.fidelity_acquisition_function

if (
isinstance(self.fidelity_thresholds, list)
and len(self.fidelity_thresholds) != num_tasks
isinstance(fid_acqf, qMFVariance)
and isinstance(fid_acqf.fidelity_thresholds, list)
and len(fid_acqf.fidelity_thresholds) != num_tasks
):
raise ValueError(
f"The number of tasks should be equal to the number of fidelity thresholds (got {num_tasks} tasks, {len(self.fidelity_thresholds)} thresholds)."
f"The number of tasks should be equal to the number of fidelity thresholds (got {num_tasks} tasks, {len(fid_acqf.fidelity_thresholds)} thresholds)."
)

return self
Expand Down
Loading
Loading