Skip to content

Commit 3135a7b

Browse files
mpolson64 authored and facebook-github-bot committed
Back out "Add support for missing tasks in mtgp" (#3004)
Summary: X-link: facebookexternal/botorch_fb#25 Pull Request resolved: #3004 X-link: facebook/Ax#4261 Original commit changeset: f92a49fb4622 Original Phabricator Diff: D79812024 Same motivation as D81695384, will be cleaned up after Ax 1.1.1 release Reviewed By: paschai Differential Revision: D81784749 fbshipit-source-id: 5f1865d1a8dd795caa703225cb4d67fcd8059f86
1 parent d2cf3e9 commit 3135a7b

File tree

8 files changed

+109
-319
lines changed

8 files changed

+109
-319
lines changed

botorch/models/fully_bayesian_multitask.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def __init__(
227227
outcome_transform: OutcomeTransform | None = None,
228228
input_transform: InputTransform | None = None,
229229
pyro_model: MultitaskSaasPyroModel | None = None,
230-
validate_task_values: bool = True,
231230
) -> None:
232231
r"""Initialize the fully Bayesian multi-task GP model.
233232
@@ -252,9 +251,6 @@ def __init__(
252251
in the model's forward pass.
253252
pyro_model: Optional `PyroModel` that has the same signature as
254253
`MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
255-
validate_task_values: If True, validate that the task values supplied in the
256-
input are expected tasks values. If false, unexpected task values
257-
will be mapped to the first output_task if supplied.
258254
"""
259255
if not (
260256
train_X.ndim == train_Y.ndim == 2
@@ -292,19 +288,22 @@ def __init__(
292288
# set on `self` below, it will be applied to the posterior in the
293289
# `posterior` method of `MultiTaskGP`.
294290
outcome_transform=None,
295-
all_tasks=all_tasks,
296-
validate_task_values=validate_task_values,
297291
)
292+
if all_tasks is not None and self._expected_task_values != set(all_tasks):
293+
raise NotImplementedError(
294+
"The `all_tasks` argument is not supported by SAAS MTGP. "
295+
f"The training data includes tasks {self._expected_task_values}, "
296+
f"got {all_tasks=}."
297+
)
298298
self.to(train_X)
299+
299300
self.mean_module = None
300301
self.covar_module = None
301302
self.likelihood = None
302303
if pyro_model is None:
303304
pyro_model = MultitaskSaasPyroModel()
304-
# apply task_mapper
305-
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
306305
pyro_model.set_inputs(
307-
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
306+
train_X=transformed_X,
308307
train_Y=train_Y,
309308
train_Yvar=train_Yvar,
310309
task_feature=task_feature,

botorch/models/gpytorch.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,39 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
802802
"long-format" multi-task GP in the style of `MultiTaskGP`.
803803
"""
804804

805+
def _map_tasks(self, task_values: Tensor) -> Tensor:
    """Map raw task values to the task indices used by the model.

    Two cases are handled:

    * No task mapper is available (``_task_mapper`` is ``None`` or was never
      set on this instance): the values are only validated to lie in
      ``[0, self.num_tasks)`` and are returned unchanged.
    * A task mapper is available: the values are cast to ``long``, checked
      against ``self._expected_task_values`` and then used as indices into
      the mapper tensor.

    Args:
        task_values: A tensor of task values.

    Returns:
        A tensor of task indices with the same shape as the input
        tensor.

    Raises:
        ValueError: If, without a task mapper, any value lies outside
            ``[0, self.num_tasks)``, or if, with a task mapper, any value
            is not one of ``self._expected_task_values``.
    """
    # This method lives on the abstract base class, while `_task_mapper` is
    # registered by concrete models. Look it up defensively so subclasses
    # that never register one fall back to plain range validation instead
    # of failing with an AttributeError.
    task_mapper = getattr(self, "_task_mapper", None)

    if task_mapper is None:
        in_range = torch.all(0 <= task_values) and torch.all(
            task_values < self.num_tasks
        )
        if not in_range:
            raise ValueError(
                "Expected all task features in `X` to be between 0 and "
                f"self.num_tasks - 1. Got {task_values}."
            )
        return task_values

    # Integer indices are required to index into the mapper tensor.
    raw_values = task_values.long()
    unexpected_task_values = set(raw_values.unique().tolist()).difference(
        self._expected_task_values
    )
    if unexpected_task_values:
        raise ValueError(
            "Received invalid raw task values. Expected raw value to be in"
            f" {self._expected_task_values}, but got unexpected task values:"
            f" {unexpected_task_values}."
        )
    return task_mapper[raw_values]
837+
805838
def _apply_noise(
806839
self,
807840
X: Tensor,

botorch/models/multitask.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def __init__(
115115
all_tasks: list[int] | None = None,
116116
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
117117
input_transform: InputTransform | None = None,
118-
validate_task_values: bool = True,
119118
) -> None:
120119
r"""Multi-Task GP model using an ICM kernel.
121120
@@ -158,9 +157,6 @@ def __init__(
158157
instantiation of the model.
159158
input_transform: An input transform that is applied in the model's
160159
forward pass.
161-
validate_task_values: If True, validate that the task values supplied in the
162-
input are expected tasks values. If false, unexpected task values
163-
will be mapped to the first output_task if supplied.
164160
165161
Example:
166162
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
193189
"This is not allowed as it will lead to errors during model training."
194190
)
195191
all_tasks = all_tasks or all_tasks_inferred
196-
self.num_tasks = len(all_tasks_inferred)
192+
self.num_tasks = len(all_tasks)
197193
if outcome_transform == DEFAULT:
198194
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
199195
if outcome_transform is not None:
@@ -263,61 +259,19 @@ def __init__(
263259

264260
self.covar_module = data_covar_module * task_covar_module
265261
task_mapper = get_task_value_remapping(
266-
observed_task_values=torch.tensor(
267-
all_tasks_inferred, dtype=torch.long, device=train_X.device
268-
),
269-
all_task_values=torch.tensor(
270-
sorted(all_tasks), dtype=torch.long, device=train_X.device
262+
task_values=torch.tensor(
263+
all_tasks, dtype=torch.long, device=train_X.device
271264
),
272265
dtype=train_X.dtype,
273-
default_task_value=None if output_tasks is None else output_tasks[0],
274266
)
275267
self.register_buffer("_task_mapper", task_mapper)
276-
self._expected_task_values = set(all_tasks_inferred)
268+
self._expected_task_values = set(all_tasks)
277269
if input_transform is not None:
278270
self.input_transform = input_transform
279271
if outcome_transform is not None:
280272
self.outcome_transform = outcome_transform
281-
self._validate_task_values = validate_task_values
282273
self.to(train_X)
283274

284-
def _map_tasks(self, task_values: Tensor) -> Tensor:
285-
"""Map raw task values to the task indices used by the model.
286-
287-
Args:
288-
task_values: A tensor of task values.
289-
290-
Returns:
291-
A tensor of task indices with the same shape as the input
292-
tensor.
293-
"""
294-
long_task_values = task_values.long()
295-
if self._validate_task_values:
296-
if self._task_mapper is None:
297-
if not (
298-
torch.all(0 <= task_values)
299-
and torch.all(task_values < self.num_tasks)
300-
):
301-
raise ValueError(
302-
"Expected all task features in `X` to be between 0 and "
303-
f"self.num_tasks - 1. Got {task_values}."
304-
)
305-
else:
306-
unexpected_task_values = set(
307-
long_task_values.unique().tolist()
308-
).difference(self._expected_task_values)
309-
if len(unexpected_task_values) > 0:
310-
raise ValueError(
311-
"Received invalid raw task values. Expected raw value to be in"
312-
f" {self._expected_task_values}, but got unexpected task"
313-
f" values: {unexpected_task_values}."
314-
)
315-
task_values = self._task_mapper[long_task_values]
316-
elif self._task_mapper is not None:
317-
task_values = self._task_mapper[long_task_values]
318-
319-
return task_values
320-
321275
def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
322276
r"""Extracts features before task feature, task indices, and features after
323277
the task feature.
@@ -330,7 +284,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
330284
3-element tuple containing
331285
332286
- A `q x d` or `b x q x d` tensor with features before the task feature
333-
- A `q` or `b x q x 1` tensor with mapped task indices
287+
- A `q` or `b x q` tensor with mapped task indices
334288
- A `q x d` or `b x q x d` tensor with features after the task feature
335289
"""
336290
batch_shape = x.shape[:-2]
@@ -370,7 +324,7 @@ def get_all_tasks(
370324
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
371325
task_feature = task_feature % (d + 1)
372326
all_tasks = (
373-
train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
327+
train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
374328
)
375329
return all_tasks, task_feature, d
376330

botorch/models/transforms/outcome.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -511,13 +511,11 @@ class StratifiedStandardize(Standardize):
511511

512512
def __init__(
513513
self,
514+
task_values: Tensor,
514515
stratification_idx: int,
515-
observed_task_values: Tensor,
516-
all_task_values: Tensor,
517516
batch_shape: torch.Size = torch.Size(), # noqa: B008
518517
min_stdv: float = 1e-8,
519-
dtype: torch.dtype = torch.double,
520-
default_task_value: int | None = None,
518+
# dtype: torch.dtype = torch.double,
521519
) -> None:
522520
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
523521
@@ -530,21 +528,13 @@ def __init__(
530528
batch_shape: The batch_shape of the training targets.
531529
min_stdv: The minimum standard deviation for which to perform
532530
standardization (if lower, only de-mean the data).
533-
default_task_value: The default task value that unexpected tasks are
534-
mapped to. This is used in `get_task_value_remapping`.
535-
536531
"""
537532
OutcomeTransform.__init__(self)
538533
self._stratification_idx = stratification_idx
539-
observed_task_values = observed_task_values.unique(sorted=True)
540-
self.strata_mapping = get_task_value_remapping(
541-
observed_task_values=observed_task_values,
542-
all_task_values=all_task_values.unique(sorted=True),
543-
dtype=dtype,
544-
default_task_value=default_task_value,
545-
)
534+
task_values = task_values.unique(sorted=True)
535+
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
546536
if self.strata_mapping is None:
547-
self.strata_mapping = observed_task_values
537+
self.strata_mapping = task_values
548538
n_strata = self.strata_mapping.shape[0]
549539
self._min_stdv = min_stdv
550540
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))

botorch/models/utils/assorted.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,13 @@ class fantasize(_Flag):
405405
_state: bool = False
406406

407407

408-
def get_task_value_remapping(
409-
observed_task_values: Tensor,
410-
all_task_values: Tensor,
411-
dtype: torch.dtype,
412-
default_task_value: int | None,
413-
) -> Tensor | None:
414-
"""Construct an mapping of observed task values to contiguous int-valued floats.
408+
def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
409+
"""Construct a mapping of discrete task values to contiguous int-valued floats.
415410
416411
Args:
417-
observed_task_values: A sorted long-valued tensor of task values.
418-
all_task_values: A sorted long-valued tensor of task values.
412+
task_values: A sorted long-valued tensor of task values.
419413
dtype: The dtype of the model inputs (e.g. `X`), which the new
420414
task values should have mapped to (e.g. float, double).
421-
default_task_value: The default task value to use for missing task values.
422415
423416
Returns:
424417
A tensor of shape `task_values.max() + 1` that maps task values
@@ -432,31 +425,17 @@ def get_task_value_remapping(
432425
if dtype not in (torch.float, torch.double):
433426
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
434427
task_range = torch.arange(
435-
len(observed_task_values),
436-
dtype=all_task_values.dtype,
437-
device=all_task_values.device,
428+
len(task_values), dtype=task_values.dtype, device=task_values.device
438429
)
439430
mapper = None
440-
441-
if default_task_value is None:
442-
fill_value = float("nan")
443-
else:
444-
mask = observed_task_values == default_task_value
445-
if not mask.any():
446-
fill_value = float("nan")
447-
else:
448-
idx = mask.nonzero().item()
449-
fill_value = task_range[idx]
450-
# if not all tasks are observed or they are not contiguous integers
451-
# then map them to contiguous integers
452-
if not torch.equal(task_range, all_task_values):
431+
if not torch.equal(task_values, task_range):
453432
# Create a tensor that maps task values to new task values.
454433
# The number of tasks should be small, so this should be quite efficient.
455434
mapper = torch.full(
456-
(int(all_task_values.max().item()) + 1,),
457-
fill_value,
435+
(int(task_values.max().item()) + 1,),
436+
float("nan"),
458437
dtype=dtype,
459-
device=all_task_values.device,
438+
device=task_values.device,
460439
)
461-
mapper[observed_task_values] = task_range.to(dtype=dtype)
440+
mapper[task_values] = task_range.to(dtype=dtype)
462441
return mapper

0 commit comments

Comments
 (0)