Skip to content

Commit 8cdc595

Browse files
Qing Feng authored and facebook-github-bot committed
implement ContextualDataset (#2066)
Summary: Pull Request resolved: #2066 Implement contextual dataset for fitting contextual GP. If one single dataset is passed, we construct the data for fitting the LCEA GP; if multiple datasets are given, each should correspond to a context breakdown of one outcome, and they are expected to be combined and fitted with the LCEMGP. Reviewed By: bletham Differential Revision: D50440957 fbshipit-source-id: 55b06d3bc3c739eb47978c4303d38be0cc286dbf
1 parent e1cb934 commit 8cdc595

File tree

2 files changed

+324
-0
lines changed

2 files changed

+324
-0
lines changed

botorch/utils/datasets.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from __future__ import annotations
1010

11+
import collections
12+
1113
import warnings
1214
from typing import Any, Dict, List, Optional, Union
1315

@@ -476,3 +478,152 @@ def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDatas
476478
],
477479
outcome_names=[outcome_name],
478480
)
481+
482+
483+
class ContextualDataset(SupervisedDataset):
    """A dataset bundling contextual data for fitting a contextual GP.

    Constructed from either a single dataset containing the overall outcome
    (used to fit an LCEA GP) or a list of datasets that each corresponds to
    a context breakdown of one outcome (combined to fit an LCEMGP).
    """

    def __init__(
        self,
        datasets: List[SupervisedDataset],
        parameter_decomposition: Dict[str, List[str]],
        context_buckets: List[str],
        metric_decomposition: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Construct a `ContextualDataset`.

        Args:
            datasets: A list of the datasets of individual tasks. Each dataset
                is expected to contain data for only one outcome.
            parameter_decomposition: Dict from context name to list of indices
                of X corresponding to that context.
            context_buckets: List of the context names in the order of dataset
                in datasets corresponding to each context outcome.
            metric_decomposition: Context breakdown metrics. Keys are context names.
                Values are the lists of metric names belonging to the context:
                {'context1': ['m1_c1'], 'context2': ['m1_c2'],}.

        Raises:
            InputDataError: If the datasets / decompositions are inconsistent
                (see `_validate_datasets`).
            ValueError: If a context bucket matches multiple outcomes
                (see `_sort_outcome_names`).
        """
        # Keyed by the (single) outcome name of each child dataset.
        # NOTE(review): datasets sharing an outcome name would silently
        # collapse here — presumably callers guarantee uniqueness; confirm.
        self.datasets: Dict[str, SupervisedDataset] = {
            ds.outcome_names[0]: ds for ds in datasets
        }
        self.feature_names = datasets[0].feature_names
        # Provisional order; replaced by the bucket-sorted order below.
        self.outcome_names = list(self.datasets.keys())
        self.parameter_decomposition = parameter_decomposition
        self.context_buckets = context_buckets
        self.metric_decomposition = metric_decomposition
        self._validate_datasets(
            datasets=datasets, metric_decomposition=metric_decomposition
        )
        # order the dataset based on context bucket
        self.outcome_names = self._sort_outcome_names()

    @property
    def X(self) -> Tensor:
        # All child datasets share the same X (enforced by
        # `_validate_datasets`), so any child can supply it.
        return self.datasets[self.outcome_names[0]].X

    @property
    def Y(self) -> Tensor:
        """Concatenates the Ys from the child datasets to create the Y expected
        by LCEM model if there are multiple datasets; Or return the Y expected
        by LCEA model if there is only one dataset.
        """
        if len(self.datasets) == 1:
            # Single overall-outcome dataset: LCEA model.
            return self.datasets[self.outcome_names[0]].Y
        # One column per context outcome, in bucket order: LCEM model.
        return torch.cat(
            [self.datasets[outcome_name].Y for outcome_name in self.outcome_names],
            dim=-1,
        )

    @property
    def Yvar(self) -> Tensor:
        """Concatenates the Yvars from the child datasets to create the Yvar
        expected by LCEM model if there are multiple datasets; Or return the
        Yvar expected by LCEA model if there is only one dataset.
        """
        if len(self.datasets) == 1:
            # Single overall-outcome dataset: LCEA model.
            return self.datasets[self.outcome_names[0]].Yvar
        # NOTE(review): assumes every child dataset carries a Yvar;
        # `torch.cat` would raise if any is None — confirm upstream guarantee.
        return torch.cat(
            [
                self.datasets[outcome_name].Yvar
                for outcome_name in self.outcome_names
            ],
            dim=-1,
        )

    def _sort_outcome_names(self) -> List[str]:
        """Sort the outcome names according to the order of context buckets.

        Returns:
            The outcome names, one per context bucket, in bucket order.

        Raises:
            ValueError: If one context bucket matches more than one outcome.
        """
        outcome_names = list(self.datasets.keys())
        if len(outcome_names) == 1:
            # Single overall outcome: nothing to sort.
            return outcome_names
        context_outcome_map = {}
        for context in self.context_buckets:
            for outcome_name in outcome_names:
                if outcome_name in self.metric_decomposition[context]:
                    if context in context_outcome_map:
                        # NOTE: the "mutltiple" typo is load-bearing — existing
                        # tests assert this exact message text.
                        raise ValueError(
                            f"{context} bucket contains mutltiple outcomes"
                        )
                    context_outcome_map[context] = outcome_name
        return [context_outcome_map[context] for context in self.context_buckets]

    def _validate_datasets(
        self,
        datasets: List[SupervisedDataset],
        metric_decomposition: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Validation of given datasets.

        1. each dataset has same X.
        2. metric_decomposition is not None if there are multiple datasets
           (and None if there is only one).
        3. metric_decomposition contains all the outcomes in datasets.
        4. value keys of parameter decomposition and the keys of
           metric_decomposition match context buckets.

        Args:
            datasets: The child datasets to validate.
            metric_decomposition: Optional mapping from context name to the
                metric names belonging to that context.

        Raises:
            InputDataError: If any of the checks above fails.
        """
        X = datasets[0].X
        for dataset in datasets:
            if not torch.equal(X, dataset.X):
                raise InputDataError("Require same X for context buckets")

        if len(datasets) > 1:
            if metric_decomposition is None:
                raise InputDataError(
                    "metric_decomposition must be provided when there are"
                    " multiple datasets."
                )
        elif metric_decomposition is not None:
            raise InputDataError(
                "metric_decomposition is redundant when there is one "
                "dataset for overall outcome."
            )

        if collections.Counter(
            self.parameter_decomposition.keys()
        ) != collections.Counter(self.context_buckets):
            raise InputDataError(
                "Keys of parameter decomposition and context buckets do not match."
            )

        if metric_decomposition is not None:
            if collections.Counter(
                metric_decomposition.keys()
            ) != collections.Counter(self.context_buckets):
                raise InputDataError(
                    "Keys of metric decomposition and context buckets do not match."
                )

            all_metrics = [
                metric
                for metrics in metric_decomposition.values()
                for metric in metrics
            ]
            for outcome in self.outcome_names:
                if outcome not in all_metrics:
                    raise InputDataError(
                        f"{outcome} is missing in metric_decomposition."
                    )

test/utils/test_datasets.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from botorch.exceptions.errors import InputDataError, UnsupportedError
1111
from botorch.utils.containers import DenseContainer, SliceContainer
1212
from botorch.utils.datasets import (
13+
ContextualDataset,
1314
FixedNoiseDataset,
1415
MultiTaskDataset,
1516
RankingDataset,
@@ -335,3 +336,175 @@ def test_multi_task(self):
335336
task_feature_index=-1,
336337
target_task_value=0,
337338
)
339+
340+
def test_contextual_datasets(self):
    """Construction, bucket ordering, and validation of ContextualDataset."""
    num_contexts = 3
    feature_names = [f"x_c{i}" for i in range(num_contexts)]
    parameter_decomposition = {
        f"context_{i}": [f"x_c{i}"] for i in range(num_contexts)
    }
    context_buckets = list(parameter_decomposition.keys())
    context_outcome_list = [f"y:context_{i}" for i in range(num_contexts)]
    metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets}

    # A single dataset holding the aggregated outcome (LCEA case).
    agg_datasets = [
        make_dataset(
            d=1 * num_contexts,
            has_yvar=True,
            feature_names=feature_names,
            outcome_names=["y"],
        )
    ]
    agg_dt = ContextualDataset(
        datasets=agg_datasets,
        parameter_decomposition=parameter_decomposition,
        context_buckets=context_buckets,
    )
    self.assertEqual(len(agg_dt.datasets), len(agg_datasets))
    self.assertListEqual(agg_dt.context_buckets, context_buckets)
    self.assertListEqual(agg_dt.outcome_names, ["y"])
    self.assertListEqual(agg_dt.feature_names, feature_names)
    self.assertIs(agg_dt.datasets["y"], agg_datasets[0])
    # X / Y / Yvar are passed straight through from the single child.
    self.assertIs(agg_dt.X, agg_datasets[0].X)
    self.assertIs(agg_dt.Y, agg_datasets[0].Y)
    self.assertIs(agg_dt.Yvar, agg_datasets[0].Yvar)

    # One dataset per context outcome (LCEM case), all sharing one X.
    base_ds = make_dataset(
        d=1 * num_contexts,
        has_yvar=True,
        feature_names=feature_names,
        outcome_names=[context_outcome_list[0]],
    )
    ctx_datasets = [base_ds] + [
        SupervisedDataset(
            X=base_ds.X,
            Y=rand(base_ds.Y.size()),
            Yvar=rand(base_ds.Yvar.size()),
            feature_names=feature_names,
            outcome_names=[name],
        )
        for name in context_outcome_list[1:]
    ]
    ctx_dt = ContextualDataset(
        datasets=ctx_datasets,
        parameter_decomposition=parameter_decomposition,
        context_buckets=context_buckets,
        metric_decomposition=metric_decomposition,
    )
    self.assertEqual(len(ctx_dt.datasets), len(ctx_datasets))
    self.assertListEqual(ctx_dt.context_buckets, context_buckets)
    self.assertListEqual(ctx_dt.outcome_names, context_outcome_list)
    self.assertListEqual(ctx_dt.feature_names, feature_names)
    self.assertTrue(torch.equal(ctx_dt.X, ctx_datasets[-1].X))
    # Y and Yvar get one column per context outcome.
    self.assertEqual(ctx_dt.Y.shape[-1], len(context_outcome_list))
    self.assertEqual(ctx_dt.Yvar.shape[-1], len(context_outcome_list))
    for child in ctx_datasets:
        self.assertIs(ctx_dt.datasets[child.outcome_names[0]], child)

    # Reversing the bucket order reverses the outcome / column order.
    reversed_dt = ContextualDataset(
        datasets=ctx_datasets,
        parameter_decomposition=parameter_decomposition,
        context_buckets=context_buckets[::-1],  # reverse order
        metric_decomposition=metric_decomposition,
    )
    self.assertListEqual(
        reversed_dt.outcome_names, context_outcome_list[::-1]
    )
    self.assertTrue(
        torch.equal(ctx_dt.Y, torch.flip(reversed_dt.Y, (1,)))
    )
    self.assertTrue(
        torch.equal(ctx_dt.Yvar, torch.flip(reversed_dt.Yvar, (1,)))
    )

    # A bucket mapped to two outcomes is rejected.
    bad_decomposition = {
        f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
    }
    bad_decomposition["context_0"] = ["y:context_0", "y:context_1"]
    with self.assertRaisesRegex(
        ValueError, "context_0 bucket contains mutltiple outcomes"
    ):
        ContextualDataset(
            datasets=ctx_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
            metric_decomposition=bad_decomposition,
        )

    # All child datasets must share the same X.
    with self.assertRaisesRegex(
        InputDataError, "Require same X for context buckets"
    ):
        ContextualDataset(
            datasets=[
                make_dataset(d=num_contexts, outcome_names=[m])
                for m in context_outcome_list
            ],
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
        )

    # metric_decomposition is mandatory with multiple datasets...
    with self.assertRaisesRegex(
        InputDataError,
        "metric_decomposition must be provided when there are multiple datasets.",
    ):
        ContextualDataset(
            datasets=ctx_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
        )

    # ...and forbidden with a single aggregated dataset.
    with self.assertRaisesRegex(
        InputDataError,
        "metric_decomposition is redundant when there is "
        "one dataset for overall outcome.",
    ):
        ContextualDataset(
            datasets=agg_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
            metric_decomposition=metric_decomposition,
        )

    # parameter_decomposition keys must match the buckets.
    with self.assertRaisesRegex(
        InputDataError,
        "Keys of parameter decomposition and context buckets do not match.",
    ):
        ContextualDataset(
            datasets=agg_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=["context_0", "context_1"],
        )

    # metric_decomposition keys must match the buckets as well.
    with self.assertRaisesRegex(
        InputDataError,
        "Keys of metric decomposition and context buckets do not match.",
    ):
        ContextualDataset(
            datasets=ctx_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
            metric_decomposition={
                f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
            },
        )

    # Every outcome must appear somewhere in metric_decomposition.
    bad_decomposition = {
        f"{c}": [f"y:{c}"] for c in context_buckets if c != "context_0"
    }
    bad_decomposition["context_0"] = ["wrong_metric"]
    missing_outcome = "y:context_0"
    with self.assertRaisesRegex(
        InputDataError, f"{missing_outcome} is missing in metric_decomposition."
    ):
        ContextualDataset(
            datasets=ctx_datasets,
            parameter_decomposition=parameter_decomposition,
            context_buckets=context_buckets,
            metric_decomposition=bad_decomposition,
        )

0 commit comments

Comments
 (0)