@@ -115,6 +115,7 @@ def __init__(
115
115
all_tasks : list [int ] | None = None ,
116
116
outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
117
117
input_transform : InputTransform | None = None ,
118
+ validate_task_values : bool = True ,
118
119
) -> None :
119
120
r"""Multi-Task GP model using an ICM kernel.
120
121
@@ -157,6 +158,9 @@ def __init__(
157
158
instantiation of the model.
158
159
input_transform: An input transform that is applied in the model's
159
160
forward pass.
161
+ validate_task_values: If True, validate that the task values supplied in the
162
+ input are expected tasks values. If false, unexpected task values
163
+ will be mapped to the first output_task if supplied.
160
164
161
165
Example:
162
166
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -189,7 +193,7 @@ def __init__(
189
193
"This is not allowed as it will lead to errors during model training."
190
194
)
191
195
all_tasks = all_tasks or all_tasks_inferred
192
- self .num_tasks = len (all_tasks )
196
+ self .num_tasks = len (all_tasks_inferred )
193
197
if outcome_transform == DEFAULT :
194
198
outcome_transform = Standardize (m = 1 , batch_shape = train_X .shape [:- 2 ])
195
199
if outcome_transform is not None :
@@ -259,19 +263,61 @@ def __init__(
259
263
260
264
self .covar_module = data_covar_module * task_covar_module
261
265
task_mapper = get_task_value_remapping (
262
- task_values = torch .tensor (
263
- all_tasks , dtype = torch .long , device = train_X .device
266
+ observed_task_values = torch .tensor (
267
+ all_tasks_inferred , dtype = torch .long , device = train_X .device
268
+ ),
269
+ all_task_values = torch .tensor (
270
+ sorted (all_tasks ), dtype = torch .long , device = train_X .device
264
271
),
265
272
dtype = train_X .dtype ,
273
+ default_task_value = None if output_tasks is None else output_tasks [0 ],
266
274
)
267
275
self .register_buffer ("_task_mapper" , task_mapper )
268
- self ._expected_task_values = set (all_tasks )
276
+ self ._expected_task_values = set (all_tasks_inferred )
269
277
if input_transform is not None :
270
278
self .input_transform = input_transform
271
279
if outcome_transform is not None :
272
280
self .outcome_transform = outcome_transform
281
+ self ._validate_task_values = validate_task_values
273
282
self .to (train_X )
274
283
284
+ def _map_tasks (self , task_values : Tensor ) -> Tensor :
285
+ """Map raw task values to the task indices used by the model.
286
+
287
+ Args:
288
+ task_values: A tensor of task values.
289
+
290
+ Returns:
291
+ A tensor of task indices with the same shape as the input
292
+ tensor.
293
+ """
294
+ long_task_values = task_values .long ()
295
+ if self ._validate_task_values :
296
+ if self ._task_mapper is None :
297
+ if not (
298
+ torch .all (0 <= task_values )
299
+ and torch .all (task_values < self .num_tasks )
300
+ ):
301
+ raise ValueError (
302
+ "Expected all task features in `X` to be between 0 and "
303
+ f"self.num_tasks - 1. Got { task_values } ."
304
+ )
305
+ else :
306
+ unexpected_task_values = set (
307
+ long_task_values .unique ().tolist ()
308
+ ).difference (self ._expected_task_values )
309
+ if len (unexpected_task_values ) > 0 :
310
+ raise ValueError (
311
+ "Received invalid raw task values. Expected raw value to be in"
312
+ f" { self ._expected_task_values } , but got unexpected task"
313
+ f" values: { unexpected_task_values } ."
314
+ )
315
+ task_values = self ._task_mapper [long_task_values ]
316
+ elif self ._task_mapper is not None :
317
+ task_values = self ._task_mapper [long_task_values ]
318
+
319
+ return task_values
320
+
275
321
def _split_inputs (self , x : Tensor ) -> tuple [Tensor , Tensor , Tensor ]:
276
322
r"""Extracts features before task feature, task indices, and features after
277
323
the task feature.
@@ -284,7 +330,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
284
330
3-element tuple containing
285
331
286
332
- A `q x d` or `b x q x d` tensor with features before the task feature
287
- - A `q` or `b x q` tensor with mapped task indices
333
+ - A `q` or `b x q x 1 ` tensor with mapped task indices
288
334
- A `q x d` or `b x q x d` tensor with features after the task feature
289
335
"""
290
336
batch_shape = x .shape [:- 2 ]
@@ -324,7 +370,7 @@ def get_all_tasks(
324
370
raise ValueError (f"Must have that -{ d } <= task_feature <= { d } " )
325
371
task_feature = task_feature % (d + 1 )
326
372
all_tasks = (
327
- train_X [..., task_feature ].unique ( sorted = True ). to (dtype = torch .long ).tolist ()
373
+ train_X [..., task_feature ].to (dtype = torch .long ). unique ( sorted = True ).tolist ()
328
374
)
329
375
return all_tasks , task_feature , d
330
376
0 commit comments