@@ -115,7 +115,6 @@ def __init__(
115
115
all_tasks : list [int ] | None = None ,
116
116
outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
117
117
input_transform : InputTransform | None = None ,
118
- validate_task_values : bool = True ,
119
118
) -> None :
120
119
r"""Multi-Task GP model using an ICM kernel.
121
120
@@ -158,9 +157,6 @@ def __init__(
158
157
instantiation of the model.
159
158
input_transform: An input transform that is applied in the model's
160
159
forward pass.
161
- validate_task_values: If True, validate that the task values supplied in the
162
- input are expected tasks values. If false, unexpected task values
163
- will be mapped to the first output_task if supplied.
164
160
165
161
Example:
166
162
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
193
189
"This is not allowed as it will lead to errors during model training."
194
190
)
195
191
all_tasks = all_tasks or all_tasks_inferred
196
- self .num_tasks = len (all_tasks_inferred )
192
+ self .num_tasks = len (all_tasks )
197
193
if outcome_transform == DEFAULT :
198
194
outcome_transform = Standardize (m = 1 , batch_shape = train_X .shape [:- 2 ])
199
195
if outcome_transform is not None :
@@ -263,61 +259,19 @@ def __init__(
263
259
264
260
self .covar_module = data_covar_module * task_covar_module
265
261
task_mapper = get_task_value_remapping (
266
- observed_task_values = torch .tensor (
267
- all_tasks_inferred , dtype = torch .long , device = train_X .device
268
- ),
269
- all_task_values = torch .tensor (
270
- sorted (all_tasks ), dtype = torch .long , device = train_X .device
262
+ task_values = torch .tensor (
263
+ all_tasks , dtype = torch .long , device = train_X .device
271
264
),
272
265
dtype = train_X .dtype ,
273
- default_task_value = None if output_tasks is None else output_tasks [0 ],
274
266
)
275
267
self .register_buffer ("_task_mapper" , task_mapper )
276
- self ._expected_task_values = set (all_tasks_inferred )
268
+ self ._expected_task_values = set (all_tasks )
277
269
if input_transform is not None :
278
270
self .input_transform = input_transform
279
271
if outcome_transform is not None :
280
272
self .outcome_transform = outcome_transform
281
- self ._validate_task_values = validate_task_values
282
273
self .to (train_X )
283
274
284
- def _map_tasks (self , task_values : Tensor ) -> Tensor :
285
- """Map raw task values to the task indices used by the model.
286
-
287
- Args:
288
- task_values: A tensor of task values.
289
-
290
- Returns:
291
- A tensor of task indices with the same shape as the input
292
- tensor.
293
- """
294
- long_task_values = task_values .long ()
295
- if self ._validate_task_values :
296
- if self ._task_mapper is None :
297
- if not (
298
- torch .all (0 <= task_values )
299
- and torch .all (task_values < self .num_tasks )
300
- ):
301
- raise ValueError (
302
- "Expected all task features in `X` to be between 0 and "
303
- f"self.num_tasks - 1. Got { task_values } ."
304
- )
305
- else :
306
- unexpected_task_values = set (
307
- long_task_values .unique ().tolist ()
308
- ).difference (self ._expected_task_values )
309
- if len (unexpected_task_values ) > 0 :
310
- raise ValueError (
311
- "Received invalid raw task values. Expected raw value to be in"
312
- f" { self ._expected_task_values } , but got unexpected task"
313
- f" values: { unexpected_task_values } ."
314
- )
315
- task_values = self ._task_mapper [long_task_values ]
316
- elif self ._task_mapper is not None :
317
- task_values = self ._task_mapper [long_task_values ]
318
-
319
- return task_values
320
-
321
275
def _split_inputs (self , x : Tensor ) -> tuple [Tensor , Tensor , Tensor ]:
322
276
r"""Extracts features before task feature, task indices, and features after
323
277
the task feature.
@@ -330,7 +284,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
330
284
3-element tuple containing
331
285
332
286
- A `q x d` or `b x q x d` tensor with features before the task feature
333
- - A `q` or `b x q x 1 ` tensor with mapped task indices
287
+ - A `q` or `b x q` tensor with mapped task indices
334
288
- A `q x d` or `b x q x d` tensor with features after the task feature
335
289
"""
336
290
batch_shape = x .shape [:- 2 ]
@@ -370,7 +324,7 @@ def get_all_tasks(
370
324
raise ValueError (f"Must have that -{ d } <= task_feature <= { d } " )
371
325
task_feature = task_feature % (d + 1 )
372
326
all_tasks = (
373
- train_X [..., task_feature ].to (dtype = torch .long ). unique ( sorted = True ).tolist ()
327
+ train_X [..., task_feature ].unique ( sorted = True ). to (dtype = torch .long ).tolist ()
374
328
)
375
329
return all_tasks , task_feature , d
376
330
0 commit comments