Commit b0d492d
Preserving train inputs and targets through transforms (#3044)
Summary:
Pull Request resolved: #3044
This PR preserves botorch transforms (specifically outcome_transforms, like Standardize) through state_dict loading. The fix also ensures that the train_targets of a leave-one-out model with outcome transforms will, in the default case, match those of the base model, minus the point left out.
__Longer explanation:__
Transforms, and specifically learnable outcome transforms like Standardize, will currently:
a. Learn their parameters at initialization of the GP
b. Transform the train_Ys into the standardized space
Then, when we load a state dict, we will:
a. Impose new standardization parameters on already standardized data
b. Potentially make the transforms re-learnable, nullifying the change made by the state dict
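To make the mechanics concrete, here is a minimal sketch of the pre-fix behavior using the Standardize outcome transform directly (the toy data and variable names are illustrative, not taken from this PR):

```python
import torch
from botorch.models.transforms.outcome import Standardize

# Toy targets with an outlier in the last position (illustrative only).
Y = torch.tensor([[0.0], [1.0], [2.0], [10.0]])

# (a) In training mode, Standardize learns its mean/std from Y, and
# (b) maps the targets into the standardized space; a GP built with
# this transform stores the transformed values as its train_targets.
base_tf = Standardize(m=1)
base_tf(Y)

# A leave-one-out transform learns different parameters because the
# outlier is excluded, and standardizes Y[:-1] with those parameters.
loo_tf = Standardize(m=1)
Y_loo_std, _ = loo_tf(Y[:-1])

# Loading the base state dict only overwrites loo_tf's means/stdvs
# buffers; Y_loo_std is not mapped back and re-standardized, so the
# stored targets and the loaded parameters no longer agree.
loo_tf.load_state_dict(base_tf.state_dict())
```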
This has undesirable consequences for cross-validation, as all cross-validated models will effectively have different training data. In essence, _we don't simply leave one point out, but instead we leave one out and re-standardize_. When the data contain outliers, this leads to substantially different predictions when an outlier is left out, since the outlier heavily influences the outcome transform parameters.
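As an illustration of the intended behavior after the fix, the following sketch (toy data; names are illustrative) builds a base SingleTaskGP and a leave-one-out copy, loads the base state dict, and checks that the leave-one-out train_targets equal the base targets minus the left-out point:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

# Toy 1-d problem with an outlier at the last training point.
train_X = torch.linspace(0, 1, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(6 * train_X)
train_Y[-1] = 10.0  # outlier

# Base model: Standardize learns its parameters from the full data.
base = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))

# Leave-one-out model: drop the outlier, then load the base state dict
# so the outcome transform uses the base standardization parameters.
loo = SingleTaskGP(
    train_X[:-1], train_Y[:-1], outcome_transform=Standardize(m=1)
)
loo.load_state_dict(base.state_dict())

# With this fix, the leave-one-out train_targets agree with the base
# model's targets minus the left-out point; previously they were
# standardized with parameters learned on the reduced data.
print(torch.allclose(loo.train_targets, base.train_targets[..., :-1]))
```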
Notebook explaining the effect with some plots: N8342965
Reviewed By: Balandat, saitcakmak
Differential Revision: D84571407
fbshipit-source-id: dafffe980d6a853733f9235ac84f2ab424b84f551
File tree (5 files changed: +440, -12 lines changed)
- botorch/models
- utils
- test/models