File tree Expand file tree Collapse file tree 1 file changed +14
-4
lines changed Expand file tree Collapse file tree 1 file changed +14
-4
lines changed Original file line number Diff line number Diff line change @@ -201,10 +201,6 @@ def test_condition_on_observations__(self):
201201
202202 # check that fantasies of batched model are correct
203203 if len (batch_shape ) > 0 and test_X .dim () == 2 :
204- state_dict_non_batch = {
205- key : (val [0 ] if val .ndim > 1 else val )
206- for key , val in model .state_dict ().items ()
207- }
208204 model_kwargs_non_batch = {
209205 "train_X" : train_X [0 ],
210206 "train_Y" : train_Y [0 ],
@@ -213,6 +209,20 @@ def test_condition_on_observations__(self):
213209 if observed_noise :
214210 model_kwargs_non_batch ["train_Yvar" ] = train_Yvar [0 ]
215211 model_non_batch = type (model )(** model_kwargs_non_batch )
212+ non_batch_shapes = {
213+ key : val .shape
214+ for key , val in model_non_batch .state_dict ().items ()
215+ }
216+ state_dict_non_batch = {}
217+ for key , val in model .state_dict ().items ():
218+ if key in non_batch_shapes :
219+ expected_shape = non_batch_shapes [key ]
220+ if val .ndim > len (expected_shape ):
221+ state_dict_non_batch [key ] = val [0 ]
222+ else :
223+ state_dict_non_batch [key ] = val
224+ else :
225+ state_dict_non_batch [key ] = val
216226 model_non_batch .load_state_dict (state_dict_non_batch )
217227 model_non_batch .eval ()
218228 model_non_batch .likelihood .eval ()
You can’t perform that action at this time.
0 commit comments