@@ -17,33 +17,52 @@ def test_fixed_features(self):
1717 train_Y = train_X .norm (dim = - 1 , keepdim = True )
1818 model = SingleTaskGP (train_X , train_Y ).to (device = self .device ).eval ()
1919 qEI = qExpectedImprovement (model , best_f = 0.0 )
20- # test single point
21- test_X = torch .rand (1 , 3 , device = self .device )
22- qEI_ff = FixedFeatureAcquisitionFunction (
23- qEI , d = 3 , columns = [2 ], values = test_X [..., - 1 :]
24- )
25- qei = qEI (test_X )
26- qei_ff = qEI_ff (test_X [..., :- 1 ])
27- self .assertTrue (torch .allclose (qei , qei_ff ))
28- # test list input
29- qEI_ff = FixedFeatureAcquisitionFunction (qEI , d = 3 , columns = [2 ], values = [0.5 ])
30- qei_ff = qEI_ff (test_X [..., :- 1 ])
31- # test q-batch
32- test_X = torch .rand (2 , 3 , device = self .device )
33- qEI_ff = FixedFeatureAcquisitionFunction (
34- qEI , d = 3 , columns = [1 ], values = test_X [..., [1 ]]
35- )
36- qei = qEI (test_X )
37- qei_ff = qEI_ff (test_X [..., [0 , 2 ]])
38- self .assertTrue (torch .allclose (qei , qei_ff ))
39- # test t-batch with broadcasting
40- test_X = torch .rand (2 , 3 , device = self .device ).expand (4 , 2 , 3 )
41- qEI_ff = FixedFeatureAcquisitionFunction (
42- qEI , d = 3 , columns = [2 ], values = test_X [0 , :, - 1 :]
43- )
44- qei = qEI (test_X )
45- qei_ff = qEI_ff (test_X [..., :- 1 ])
46- self .assertTrue (torch .allclose (qei , qei_ff ))
20+ for q in [1 , 2 ]:
21+ # test single point
22+ test_X = torch .rand (q , 3 , device = self .device )
23+ qEI_ff = FixedFeatureAcquisitionFunction (
24+ qEI , d = 3 , columns = [2 ], values = test_X [..., - 1 :]
25+ )
26+ qei = qEI (test_X )
27+ qei_ff = qEI_ff (test_X [..., :- 1 ])
28+ self .assertTrue (torch .allclose (qei , qei_ff ))
29+
30+ # test list input with float
31+ qEI_ff = FixedFeatureAcquisitionFunction (
32+ qEI , d = 3 , columns = [2 ], values = [0.5 ]
33+ )
34+ qei_ff = qEI_ff (test_X [..., :- 1 ])
35+ test_X_clone = test_X .clone ()
36+ test_X_clone [..., 2 ] = 0.5
37+ qei = qEI (test_X_clone )
38+ self .assertTrue (torch .allclose (qei , qei_ff ))
39+
40+ # test list input with Tensor and float
41+ qEI_ff = FixedFeatureAcquisitionFunction (
42+ qEI , d = 3 , columns = [0 , 2 ], values = [test_X [..., [0 ]], 0.5 ]
43+ )
44+ qei_ff = qEI_ff (test_X [..., [1 ]])
45+ self .assertTrue (torch .allclose (qei , qei_ff ))
46+
47+ # test t-batch with broadcasting and list of floats
48+ test_X = torch .rand (q , 3 , device = self .device ).expand (4 , q , 3 )
49+ qEI_ff = FixedFeatureAcquisitionFunction (
50+ qEI , d = 3 , columns = [2 ], values = test_X [0 , :, - 1 :]
51+ )
52+ qei = qEI (test_X )
53+ qei_ff = qEI_ff (test_X [..., :- 1 ])
54+ self .assertTrue (torch .allclose (qei , qei_ff ))
55+
56+ # test t-batch with broadcasting and list of floats and Tensor
57+ qEI_ff = FixedFeatureAcquisitionFunction (
58+ qEI , d = 3 , columns = [0 , 2 ], values = [test_X [0 , :, [0 ]], 0.5 ]
59+ )
60+ test_X_clone = test_X .clone ()
61+ test_X_clone [..., 2 ] = 0.5
62+ qei = qEI (test_X_clone )
63+ qei_ff = qEI_ff (test_X [..., [1 ]])
64+ self .assertTrue (torch .allclose (qei , qei_ff ))
65+
4766 # test gradient
4867 test_X = torch .rand (1 , 3 , device = self .device , requires_grad = True )
4968 test_X_ff = test_X [..., :- 1 ].detach ().clone ().requires_grad_ (True )
@@ -56,6 +75,20 @@ def test_fixed_features(self):
5675 qei .backward ()
5776 qei_ff .backward ()
5877 self .assertTrue (torch .allclose (test_X .grad [..., :- 1 ], test_X_ff .grad ))
78+
79+ test_X = test_X .detach ().clone ()
80+ test_X_ff = test_X [..., [1 ]].detach ().clone ().requires_grad_ (True )
81+ test_X [..., 2 ] = 0.5
82+ test_X .requires_grad_ (True )
83+ qei = qEI (test_X )
84+ qEI_ff = FixedFeatureAcquisitionFunction (
85+ qEI , d = 3 , columns = [0 , 2 ], values = [test_X [..., [0 ]].detach (), 0.5 ]
86+ )
87+ qei_ff = qEI_ff (test_X_ff )
88+ qei .backward ()
89+ qei_ff .backward ()
90+ self .assertTrue (torch .allclose (test_X .grad [..., [1 ]], test_X_ff .grad ))
91+
5992 # test error b/c of incompatible input shapes
6093 with self .assertRaises (ValueError ):
6194 qEI_ff (test_X )
0 commit comments