88
99from __future__ import annotations
1010
11- from dataclasses import dataclass , fields , MISSING
12- from itertools import chain , count , repeat
11+ import warnings
12+ from itertools import count , repeat
1313from typing import Any , Dict , Hashable , Iterable , Optional , TypeVar , Union
1414
1515from botorch .utils .containers import BotorchContainer , DenseContainer , SliceContainer
1616from torch import long , ones , Tensor
17- from typing_extensions import get_type_hints
1817
# Generic type variable used to parameterise `MaybeIterable`.
T = TypeVar("T")
# Accepted wherever a container is expected: a ready-made container or a raw
# Tensor (Tensors are wrapped by `_containerize` in the dataset constructor).
ContainerLike = Union[BotorchContainer, Tensor]
# Either a single item of type `T` or an iterable of such items.
MaybeIterable = Union[T, Iterable[T]]
2221
2322
24- @dataclass
25- class BotorchDataset :
26- # TODO: Once v3.10 becomes standard, expose `validate_init` as a kw_only InitVar
27- def __post_init__ (self , validate_init : bool = True ) -> None :
28- if validate_init :
29- self ._validate ()
23+ class SupervisedDataset :
24+ r"""Base class for datasets consisting of labelled pairs `(X, Y)`
25+ and an optional `Yvar` that stipulates observations variances so
26+ that `Y[i] ~ N(f(X[i]), Yvar[i])`.
3027
31- def _validate (self ) -> None :
32- pass
33-
34-
35- class SupervisedDatasetMeta (type ):
36- def __call__ (cls , * args : Any , ** kwargs : Any ):
37- r"""Converts Tensor-valued fields to DenseContainer under the assumption
38- that said fields house collections of feature vectors."""
39- hints = get_type_hints (cls )
40- fields_iter = (item for item in fields (cls ) if item .init is not None )
41- f_dict = {}
42- for value , field in chain (
43- zip (args , fields_iter ),
44- ((kwargs .pop (field .name , MISSING ), field ) for field in fields_iter ),
45- ):
46- if value is MISSING :
47- if field .default is not MISSING :
48- value = field .default
49- elif field .default_factory is not MISSING :
50- value = field .default_factory ()
51- else :
52- raise RuntimeError (f"Missing required field `{ field .name } `." )
53-
54- if issubclass (hints [field .name ], BotorchContainer ):
55- if isinstance (value , Tensor ):
56- value = DenseContainer (value , event_shape = value .shape [- 1 :])
57- elif not isinstance (value , BotorchContainer ):
58- raise TypeError (
59- "Expected <BotorchContainer | Tensor> for field "
60- f"`{ field .name } ` but was { type (value )} ."
61- )
62- f_dict [field .name ] = value
63-
64- return super ().__call__ (** f_dict , ** kwargs )
65-
66-
67- @dataclass
68- class SupervisedDataset (BotorchDataset , metaclass = SupervisedDatasetMeta ):
69- r"""Base class for datasets consisting of labelled pairs `(x, y)`.
70-
71- This class object's `__call__` method converts Tensors `src` to
28+ This class object's `__init__` method converts Tensors `src` to
7229 DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
7330
7431 Example:
@@ -87,6 +44,29 @@ class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
8744
8845 X : BotorchContainer
8946 Y : BotorchContainer
47+ Yvar : Optional [BotorchContainer ]
48+
49+ def __init__ (
50+ self ,
51+ X : ContainerLike ,
52+ Y : ContainerLike ,
53+ Yvar : Optional [ContainerLike ] = None ,
54+ validate_init : bool = True ,
55+ ) -> None :
56+ r"""Constructs a `SupervisedDataset`.
57+
58+ Args:
59+ X: A `Tensor` or `BotorchContainer` representing the input features.
60+ Y: A `Tensor` or `BotorchContainer` representing the outcomes.
61+ Yvar: An optional `Tensor` or `BotorchContainer` representing
62+ the observation noise.
63+ validate_init: If `True`, validates the input shapes.
64+ """
65+ self .X = _containerize (X )
66+ self .Y = _containerize (Y )
67+ self .Yvar = None if Yvar is None else _containerize (Yvar )
68+ if validate_init :
69+ self ._validate ()
9070
9171 def _validate (self ) -> None :
9272 shape_X = self .X .shape
@@ -95,12 +75,15 @@ def _validate(self) -> None:
9575 shape_Y = shape_Y [: len (shape_Y ) - len (self .Y .event_shape )]
9676 if shape_X != shape_Y :
9777 raise ValueError ("Batch dimensions of `X` and `Y` are incompatible." )
78+ if self .Yvar is not None and self .Yvar .shape != self .Y .shape :
79+ raise ValueError ("Shapes of `Y` and `Yvar` are incompatible." )
9880
9981 @classmethod
10082 def dict_from_iter (
10183 cls ,
10284 X : MaybeIterable [ContainerLike ],
10385 Y : MaybeIterable [ContainerLike ],
86+ Yvar : Optional [MaybeIterable [ContainerLike ]] = None ,
10487 * ,
10588 keys : Optional [Iterable [Hashable ]] = None ,
10689 ) -> Dict [Hashable , SupervisedDataset ]:
@@ -111,40 +94,46 @@ def dict_from_iter(
11194 X = (X ,) if single_Y else repeat (X )
11295 if single_Y :
11396 Y = (Y ,) if single_X else repeat (Y )
114- return {key : cls (x , y ) for key , x , y in zip (keys or count (), X , Y )}
97+ Yvar = repeat (Yvar ) if isinstance (Yvar , (Tensor , BotorchContainer )) else Yvar
98+
99+ # Pass in Yvar only if it is not None.
100+ iterables = (X , Y ) if Yvar is None else (X , Y , Yvar )
101+ return {
102+ elements [0 ]: cls (* elements [1 :])
103+ for elements in zip (keys or count (), * iterables )
104+ }
105+
106+ def __eq__ (self , other : Any ) -> bool :
107+ return (
108+ type (other ) is type (self )
109+ and self .X == other .X
110+ and self .Y == other .Y
111+ and self .Yvar == other .Yvar
112+ )
115113
116114
117- @dataclass
class FixedNoiseDataset(SupervisedDataset):
    r"""A SupervisedDataset with an additional field `Yvar` that stipulates
    observation variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`.

    NOTE: This is deprecated. Use `SupervisedDataset` instead.
    """

    def __init__(
        self,
        X: ContainerLike,
        Y: ContainerLike,
        Yvar: ContainerLike,
        validate_init: bool = True,
    ) -> None:
        r"""Initialize a `FixedNoiseDataset` -- deprecated!

        Emits a `DeprecationWarning` and then defers to
        `SupervisedDataset.__init__` with the same arguments.

        Args:
            X: A `Tensor` or `BotorchContainer` representing the input features.
            Y: A `Tensor` or `BotorchContainer` representing the outcomes.
            Yvar: A `Tensor` or `BotorchContainer` representing the observation
                noise. Unlike in the parent class, this argument is required.
            validate_init: Forwarded unchanged to the parent constructor.
        """
        warnings.warn(
            "`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead.",
            DeprecationWarning,
            # Attribute the warning to the code constructing the dataset rather
            # than to this module, so that the offending call site is reported.
            stacklevel=2,
        )
        super().__init__(X=X, Y=Y, Yvar=Yvar, validate_init=validate_init)
145135
146136
147- @dataclass
148137class RankingDataset (SupervisedDataset ):
149138 r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
150139 `x ∈ Z^{m}` of elements from a ground set `Z = (z_1, ...)` and ranking vectors
@@ -173,6 +162,18 @@ class RankingDataset(SupervisedDataset):
173162 X : SliceContainer
174163 Y : BotorchContainer
175164
165+ def __init__ (
166+ self , X : SliceContainer , Y : ContainerLike , validate_init : bool = True
167+ ) -> None :
168+ r"""Construct a `RankingDataset`.
169+
170+ Args:
171+ X: A `SliceContainer` representing the input features being ranked.
172+ Y: A `Tensor` or `BotorchContainer` representing the rankings.
173+ validate_init: If `True`, validates the input shapes.
174+ """
175+ super ().__init__ (X = X , Y = Y , Yvar = None , validate_init = validate_init )
176+
176177 def _validate (self ) -> None :
177178 super ()._validate ()
178179
@@ -201,3 +202,13 @@ def _validate(self) -> None:
201202
202203 # Same as: torch.where(y_diff == 0, y_incr + 1, 1)
203204 y_incr = y_incr - y_diff + 1
205+
206+
def _containerize(value: ContainerLike) -> BotorchContainer:
    r"""Converts Tensor-valued arguments to DenseContainer under the assumption
    that said arguments house collections of feature vectors.

    Args:
        value: A `Tensor` or an already constructed container.

    Returns:
        A `DenseContainer` wrapping `value` with `event_shape=value.shape[-1:]`
        when `value` is a `Tensor`; otherwise `value` itself, unchanged. No
        type check is performed on non-Tensor inputs.
    """
    if isinstance(value, Tensor):
        # The trailing dimension is taken to be the feature (event) dimension.
        return DenseContainer(value, event_shape=value.shape[-1:])
    return value
0 commit comments