
Commit eba2dce

mpolson64 authored and facebook-github-bot committed
Back out "Back out "[botorch/ax] Add support for missing tasks in mtgp"" (#3006)
Summary:
X-link: facebookexternal/botorch_fb#26
Pull Request resolved: #3006
X-link: facebook/Ax#4266
Original commit changeset: 5f1865d1a8dd
Original Phabricator Diff: D81784749

Reviewed By: sdaulton

Differential Revision: D81805854

fbshipit-source-id: 6287e74957f4b22bb8558f0a946307e483bc637c
1 parent: 1518b30

File tree: 8 files changed, +319 −109 lines


botorch/models/fully_bayesian_multitask.py

Lines changed: 9 additions & 8 deletions
@@ -227,6 +227,7 @@ def __init__(
         outcome_transform: OutcomeTransform | None = None,
         input_transform: InputTransform | None = None,
         pyro_model: MultitaskSaasPyroModel | None = None,
+        validate_task_values: bool = True,
     ) -> None:
         r"""Initialize the fully Bayesian multi-task GP model.

@@ -251,6 +252,9 @@ def __init__(
                 in the model's forward pass.
             pyro_model: Optional `PyroModel` that has the same signature as
                 `MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
+            validate_task_values: If True, validate that the task values supplied in
+                the input are expected task values. If False, unexpected task values
+                will be mapped to the first of `output_tasks`, if supplied.
         """
         if not (
             train_X.ndim == train_Y.ndim == 2

@@ -288,22 +292,19 @@ def __init__(
             # set on `self` below, it will be applied to the posterior in the
             # `posterior` method of `MultiTaskGP`.
             outcome_transform=None,
+            all_tasks=all_tasks,
+            validate_task_values=validate_task_values,
         )
-        if all_tasks is not None and self._expected_task_values != set(all_tasks):
-            raise NotImplementedError(
-                "The `all_tasks` argument is not supported by SAAS MTGP. "
-                f"The training data includes tasks {self._expected_task_values}, "
-                f"got {all_tasks=}."
-            )
         self.to(train_X)
-
         self.mean_module = None
         self.covar_module = None
         self.likelihood = None
         if pyro_model is None:
             pyro_model = MultitaskSaasPyroModel()
+        # apply task_mapper
+        x_before, task_idcs, x_after = self._split_inputs(transformed_X)
         pyro_model.set_inputs(
-            train_X=transformed_X,
+            train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
             train_Y=train_Y,
             train_Yvar=train_Yvar,
             task_feature=task_feature,
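For context, a minimal sketch of how the new `validate_task_values` argument could be passed through, assuming this diff targets the `SaasFullyBayesianMultiTaskGP` constructor; the toy data, shapes, and task ids below are illustrative:

```python
import torch
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

# Toy long-format data: the last column is the task feature with raw ids 0 and 2.
X = torch.rand(8, 2, dtype=torch.double)
task_ids = torch.tensor([0, 0, 0, 0, 2, 2, 2, 2], dtype=torch.double).unsqueeze(-1)
train_X = torch.cat([X, task_ids], dim=-1)
train_Y = torch.rand(8, 1, dtype=torch.double)

# Per this change, `all_tasks` may include tasks with no training data, and task
# validation can be relaxed via `validate_task_values`.
model = SaasFullyBayesianMultiTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    task_feature=-1,
    all_tasks=[0, 1, 2],        # task 1 has no observations
    validate_task_values=True,  # default; set False to remap unexpected task values
)
```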

botorch/models/gpytorch.py

Lines changed: 0 additions & 33 deletions
@@ -802,39 +802,6 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
     "long-format" multi-task GP in the style of `MultiTaskGP`.
     """

-    def _map_tasks(self, task_values: Tensor) -> Tensor:
-        """Map raw task values to the task indices used by the model.
-
-        Args:
-            task_values: A tensor of task values.
-
-        Returns:
-            A tensor of task indices with the same shape as the input
-            tensor.
-        """
-        if self._task_mapper is None:
-            if not (
-                torch.all(0 <= task_values) and torch.all(task_values < self.num_tasks)
-            ):
-                raise ValueError(
-                    "Expected all task features in `X` to be between 0 and "
-                    f"self.num_tasks - 1. Got {task_values}."
-                )
-        else:
-            task_values = task_values.long()
-
-            unexpected_task_values = set(task_values.unique().tolist()).difference(
-                self._expected_task_values
-            )
-            if len(unexpected_task_values) > 0:
-                raise ValueError(
-                    "Received invalid raw task values. Expected raw value to be in"
-                    f" {self._expected_task_values}, but got unexpected task values:"
-                    f" {unexpected_task_values}."
-                )
-            task_values = self._task_mapper[task_values]
-        return task_values
-
     def _apply_noise(
         self,
         X: Tensor,

botorch/models/multitask.py

Lines changed: 52 additions & 6 deletions
@@ -115,6 +115,7 @@ def __init__(
         all_tasks: list[int] | None = None,
         outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
         input_transform: InputTransform | None = None,
+        validate_task_values: bool = True,
     ) -> None:
         r"""Multi-Task GP model using an ICM kernel.

@@ -157,6 +158,9 @@ def __init__(
                 instantiation of the model.
             input_transform: An input transform that is applied in the model's
                 forward pass.
+            validate_task_values: If True, validate that the task values supplied in
+                the input are expected task values. If False, unexpected task values
+                will be mapped to the first of `output_tasks`, if supplied.

         Example:
             >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)

@@ -189,7 +193,7 @@ def __init__(
                 "This is not allowed as it will lead to errors during model training."
             )
         all_tasks = all_tasks or all_tasks_inferred
-        self.num_tasks = len(all_tasks)
+        self.num_tasks = len(all_tasks_inferred)
         if outcome_transform == DEFAULT:
             outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
         if outcome_transform is not None:

@@ -259,19 +263,61 @@ def __init__(

         self.covar_module = data_covar_module * task_covar_module
         task_mapper = get_task_value_remapping(
-            task_values=torch.tensor(
-                all_tasks, dtype=torch.long, device=train_X.device
+            observed_task_values=torch.tensor(
+                all_tasks_inferred, dtype=torch.long, device=train_X.device
+            ),
+            all_task_values=torch.tensor(
+                sorted(all_tasks), dtype=torch.long, device=train_X.device
             ),
             dtype=train_X.dtype,
+            default_task_value=None if output_tasks is None else output_tasks[0],
         )
         self.register_buffer("_task_mapper", task_mapper)
-        self._expected_task_values = set(all_tasks)
+        self._expected_task_values = set(all_tasks_inferred)
         if input_transform is not None:
             self.input_transform = input_transform
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
+        self._validate_task_values = validate_task_values
         self.to(train_X)

+    def _map_tasks(self, task_values: Tensor) -> Tensor:
+        """Map raw task values to the task indices used by the model.
+
+        Args:
+            task_values: A tensor of task values.
+
+        Returns:
+            A tensor of task indices with the same shape as the input
+            tensor.
+        """
+        long_task_values = task_values.long()
+        if self._validate_task_values:
+            if self._task_mapper is None:
+                if not (
+                    torch.all(0 <= task_values)
+                    and torch.all(task_values < self.num_tasks)
+                ):
+                    raise ValueError(
+                        "Expected all task features in `X` to be between 0 and "
+                        f"self.num_tasks - 1. Got {task_values}."
+                    )
+            else:
+                unexpected_task_values = set(
+                    long_task_values.unique().tolist()
+                ).difference(self._expected_task_values)
+                if len(unexpected_task_values) > 0:
+                    raise ValueError(
+                        "Received invalid raw task values. Expected raw value to be in"
+                        f" {self._expected_task_values}, but got unexpected task"
+                        f" values: {unexpected_task_values}."
+                    )
+                task_values = self._task_mapper[long_task_values]
+        elif self._task_mapper is not None:
+            task_values = self._task_mapper[long_task_values]
+
+        return task_values
+
     def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         r"""Extracts features before task feature, task indices, and features after
         the task feature.

@@ -284,7 +330,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
            3-element tuple containing

            - A `q x d` or `b x q x d` tensor with features before the task feature
-           - A `q` or `b x q` tensor with mapped task indices
+           - A `q` or `b x q x 1` tensor with mapped task indices
            - A `q x d` or `b x q x d` tensor with features after the task feature
        """
        batch_shape = x.shape[:-2]

@@ -324,7 +370,7 @@ def get_all_tasks(
        raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
    task_feature = task_feature % (d + 1)
    all_tasks = (
-       train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
+       train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
    )
    return all_tasks, task_feature, d
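A rough usage sketch of the new flag on `MultiTaskGP`, based on the docstring added above; the data, shapes, and task ids are illustrative:

```python
import torch
from botorch.models.multitask import MultiTaskGP

# Long-format training data: the last column holds raw task ids, here only 0 and 2.
X = torch.rand(12, 2, dtype=torch.double)
task_ids = torch.tensor([0] * 6 + [2] * 6, dtype=torch.double).unsqueeze(-1)
train_X = torch.cat([X, task_ids], dim=-1)
train_Y = torch.rand(12, 1, dtype=torch.double)

# `all_tasks` declares task 1 even though it is never observed in training. With
# validate_task_values=False, an unexpected task value seen at prediction time is
# remapped to the first entry of `output_tasks` instead of raising a ValueError.
model = MultiTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    task_feature=-1,
    all_tasks=[0, 1, 2],
    output_tasks=[0],
    validate_task_values=False,
)
```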

botorch/models/transforms/outcome.py

Lines changed: 15 additions & 5 deletions
@@ -511,11 +511,13 @@ class StratifiedStandardize(Standardize):

     def __init__(
         self,
-        task_values: Tensor,
         stratification_idx: int,
+        observed_task_values: Tensor,
+        all_task_values: Tensor,
         batch_shape: torch.Size = torch.Size(),  # noqa: B008
         min_stdv: float = 1e-8,
-        # dtype: torch.dtype = torch.double,
+        dtype: torch.dtype = torch.double,
+        default_task_value: int | None = None,
     ) -> None:
         r"""Standardize outcomes (zero mean, unit variance) along stratification dim.

@@ -528,13 +530,21 @@ def __init__(
             batch_shape: The batch_shape of the training targets.
             min_stddv: The minimum standard deviation for which to perform
                 standardization (if lower, only de-mean the data).
+            default_task_value: The default task value that unexpected tasks are
+                mapped to. This is used in `get_task_value_remapping`.
+
         """
         OutcomeTransform.__init__(self)
         self._stratification_idx = stratification_idx
-        task_values = task_values.unique(sorted=True)
-        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
+        observed_task_values = observed_task_values.unique(sorted=True)
+        self.strata_mapping = get_task_value_remapping(
+            observed_task_values=observed_task_values,
+            all_task_values=all_task_values.unique(sorted=True),
+            dtype=dtype,
+            default_task_value=default_task_value,
+        )
         if self.strata_mapping is None:
-            self.strata_mapping = task_values
+            self.strata_mapping = observed_task_values
         n_strata = self.strata_mapping.shape[0]
         self._min_stdv = min_stdv
         self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
botorch/models/utils/assorted.py

Lines changed: 30 additions & 9 deletions
@@ -405,13 +405,20 @@ class fantasize(_Flag):
     _state: bool = False


-def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
-    """Construct an mapping of discrete task values to contiguous int-valued floats.
+def get_task_value_remapping(
+    observed_task_values: Tensor,
+    all_task_values: Tensor,
+    dtype: torch.dtype,
+    default_task_value: int | None,
+) -> Tensor | None:
+    """Construct a mapping of observed task values to contiguous int-valued floats.

     Args:
-        task_values: A sorted long-valued tensor of task values.
+        observed_task_values: A sorted long-valued tensor of task values.
+        all_task_values: A sorted long-valued tensor of task values.
         dtype: The dtype of the model inputs (e.g. `X`), which the new
             task values should have mapped to (e.g. float, double).
+        default_task_value: The default task value to use for missing task values.

     Returns:
         A tensor of shape `task_values.max() + 1` that maps task values

@@ -425,17 +432,31 @@ def get_task_value_remapping(
    if dtype not in (torch.float, torch.double):
        raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
    task_range = torch.arange(
-       len(task_values), dtype=task_values.dtype, device=task_values.device
+       len(observed_task_values),
+       dtype=all_task_values.dtype,
+       device=all_task_values.device,
    )
    mapper = None
-   if not torch.equal(task_values, task_range):
+
+   if default_task_value is None:
+       fill_value = float("nan")
+   else:
+       mask = observed_task_values == default_task_value
+       if not mask.any():
+           fill_value = float("nan")
+       else:
+           idx = mask.nonzero().item()
+           fill_value = task_range[idx]
+   # if not all tasks are observed or they are not contiguous integers
+   # then map them to contiguous integers
+   if not torch.equal(task_range, all_task_values):
        # Create a tensor that maps task values to new task values.
        # The number of tasks should be small, so this should be quite efficient.
        mapper = torch.full(
-           (int(task_values.max().item()) + 1,),
-           float("nan"),
+           (int(all_task_values.max().item()) + 1,),
+           fill_value,
            dtype=dtype,
-           device=task_values.device,
+           device=all_task_values.device,
        )
-       mapper[task_values] = task_range.to(dtype=dtype)
+       mapper[observed_task_values] = task_range.to(dtype=dtype)
    return mapper
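For intuition, a small worked example of the remapping under the new signature shown above; the task ids are illustrative:

```python
import torch
from botorch.models.utils.assorted import get_task_value_remapping

# Tasks 0 and 2 were observed in training; task 1 exists but has no data.
# With default_task_value=0, the unobserved task 1 falls back to task 0's index.
mapper = get_task_value_remapping(
    observed_task_values=torch.tensor([0, 2]),
    all_task_values=torch.tensor([0, 1, 2]),
    dtype=torch.double,
    default_task_value=0,
)
print(mapper)  # tensor([0., 0., 1.], dtype=torch.float64)
# raw task 0 -> index 0, raw task 1 (unobserved) -> index 0, raw task 2 -> index 1
```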
