Description
What happened?
I have a specific application in which MultiTaskGP receives some tasks not present in the training data. As far as I understand, the solution is to set the all_tasks argument of the MultiTaskGP.
It worked fine on botorch==0.15.0 but fails on botorch==0.16.0: even with all_tasks set, it raises ValueError: Received invalid raw task values. I also verified that the num_tasks property equals the number of tasks inferred from the training data, not the number passed in all_tasks.
I noticed that there were some recent changes (#3006) in which the logic regarding missing tasks was changed, resulting in the substitution of all_tasks by all_tasks_inferred in many parts of the code. I was not able to find much information about the changes, so I apologize if I'm getting this error because I'm misunderstanding how to use the class.
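To illustrate what I think is happening, here is a simplified sketch in plain Python. It is not BoTorch's code: `map_tasks` and `valid_tasks` are hypothetical names, and the only thing borrowed from the library is the wording of the error message. It shows how the same test tasks pass when validated against the declared task list but fail when validated against the tasks inferred from the training data.

```python
# Illustrative only: a simplified stand-in for the kind of check that seems to
# produce the error. NOT BoTorch's implementation; `map_tasks` and
# `valid_tasks` are hypothetical names.


def map_tasks(task_values, valid_tasks):
    """Map raw task values to contiguous indices, rejecting unknown values."""
    mapping = {raw: idx for idx, raw in enumerate(sorted(valid_tasks))}
    unexpected = set(task_values) - set(mapping)
    if unexpected:
        raise ValueError(
            "Received invalid raw task values. Expected raw value to be in "
            f"{set(mapping)}, but got unexpected task values: {unexpected}."
        )
    return [mapping[t] for t in task_values]


train_tasks = [0, 1, 2, 1, 0]  # tasks observed in the training data
all_tasks = [0, 1, 2, 3, 4]    # tasks declared up front
test_tasks = [0, 1, 2, 3, 4]   # tasks requested at prediction time

# Validating against the declared tasks accepts every test task.
print(map_tasks(test_tasks, valid_tasks=all_tasks))

# Validating against the tasks inferred from training data rejects 3 and 4.
try:
    map_tasks(test_tasks, valid_tasks=set(train_tasks))
except ValueError as err:
    print(err)
```

If the 0.16.0 behavior is intended, the second case is what I would need to work around; if not, the first case is what I expected from passing all_tasks.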
I'm happy to help if there's anything that needs to be changed.
Thanks!
Please provide a minimal, reproducible example of the unexpected behavior.
import torch
from botorch.models.multitask import MultiTaskGP
# Creating training data with 3 tasks (0, 1, 2)
train_data = torch.randn(10, 2, dtype=torch.float64)
train_tasks = torch.randint(0, 3, (10, 1))
train_X = torch.cat([train_data, train_tasks], dim=1)
train_Y = torch.randn(10, 1, dtype=torch.float64)
# Creating testing data with 5 tasks (0, 1, 2, 3, 4)
test_data = torch.randn(5, 2, dtype=torch.float64)
test_tasks = torch.arange(5).reshape(-1, 1)
test_X = torch.cat([test_data, test_tasks], dim=1)
# Creating multi-task GP with 5 tasks
gp = MultiTaskGP(
    train_X=train_X, train_Y=train_Y, task_feature=2, all_tasks=list(range(5))
)
print(gp.num_tasks)  # outputs 3 (number of inferred tasks)
gp.posterior(test_X)  # raises error
Please paste any relevant traceback/logs produced by the example provided.
Traceback (most recent call last):
  File "/Users/gsutterp/tmp/multitaskgp.py", line 21, in <module>
    gp.posterior(test_X)
    ~~~~~~~~~~~~^^^^^^^^
  File "/Users/gsutterp/base_env/lib/python3.13/site-packages/botorch/models/gpytorch.py", line 1059, in posterior
    task_features = self._map_tasks(task_values=task_features)
  File "/Users/gsutterp/base_env/lib/python3.13/site-packages/botorch/models/multitask.py", line 310, in _map_tasks
    raise ValueError(
    ...<3 lines>...
    )
ValueError: Received invalid raw task values. Expected raw value to be in {0, 1, 2}, but got unexpected task values: {3, 4}.
BoTorch Version
0.16.0
Python Version
No response
Operating System
No response
(Optional) Describe any potential fixes you've considered to the issue outlined above.
No response
Pull Request
None
Code of Conduct
- I agree to follow BoTorch's Code of Conduct