@@ -74,7 +74,7 @@ def _get_model(
7474 1.1000 , ** tkwargs
7575 ),
7676 "likelihood.noise_covar.noise_prior.rate" : torch .tensor (0.0500 , ** tkwargs ),
77- "mean_module.constant " : torch .tensor ([ 0.1398 ] , ** tkwargs ),
77+ "mean_module.raw_constant " : torch .tensor (0.1398 , ** tkwargs ),
7878 "covar_module.raw_outputscale" : torch .tensor (0.6933 , ** tkwargs ),
7979 "covar_module.base_kernel.raw_lengthscale" : torch .tensor (
8080 [[- 0.0444 ]], ** tkwargs
@@ -111,8 +111,8 @@ def _get_model(
111111 torch .tensor ([0.0745 ], ** tkwargs ),
112112 ]
113113 )
114- state_dict ["mean_module.constant " ] = torch .stack (
115- [state_dict ["mean_module.constant " ], torch .tensor ([ 0.3276 ] , ** tkwargs )]
114+ state_dict ["mean_module.raw_constant " ] = torch .stack (
115+ [state_dict ["mean_module.raw_constant " ], torch .tensor (0.3276 , ** tkwargs )]
116116 )
117117 state_dict ["covar_module.raw_outputscale" ] = torch .stack (
118118 [
@@ -134,7 +134,7 @@ def _get_model(
134134 state_dict ["likelihood.noise_covar.raw_noise" ] = torch .tensor (
135135 [[0.0214 ], [0.001 ]], ** tkwargs
136136 )
137- state_dict ["mean_module.constant " ] = torch .tensor ([[ 0.1398 ], [ 0.5 ] ], ** tkwargs )
137+ state_dict ["mean_module.raw_constant " ] = torch .tensor ([0.1398 , 0.5 ], ** tkwargs )
138138 state_dict ["covar_module.raw_outputscale" ] = torch .tensor (
139139 [0.6933 , 1.0 ], ** tkwargs
140140 )
@@ -153,8 +153,8 @@ def _get_model(
153153 state_dict ["likelihood.noise_covar.raw_noise" ] = torch .tensor (
154154 [[0.1743 ], [0.3132 ]] if multi_output else [0.1743 ], ** tkwargs
155155 )
156- state_dict ["mean_module.constant " ] = torch .tensor (
157- [[ 0.2560 ], [ 0.6714 ]] if multi_output else [ 0.2555 ] , ** tkwargs
156+ state_dict ["mean_module.raw_constant " ] = torch .tensor (
157+ [0.2560 , 0.6714 ] if multi_output else 0.2555 , ** tkwargs
158158 )
159159 state_dict ["covar_module.raw_outputscale" ] = torch .tensor (
160160 [2.4396 , 2.6821 ] if multi_output else 2.4398 , ** tkwargs
@@ -187,15 +187,15 @@ def test_gp_draw_single_output(self):
187187 for dtype in (torch .float , torch .double ):
188188 tkwargs = {"device" : self .device , "dtype" : dtype }
189189 model , _ , _ = _get_model (** tkwargs )
190- mean = model .mean_module .constant .detach ().clone ()
190+ mean = model .mean_module .raw_constant .detach ().clone ()
191191 gp = GPDraw (model )
192192 # test initialization
193193 self .assertIsNone (gp .Xs )
194194 self .assertIsNone (gp .Ys )
195195 self .assertIsNotNone (gp ._seed )
196196 # make sure model is actually deepcopied
197- model .mean_module .constant = None
198- self .assertTrue (torch .equal (gp ._model .mean_module .constant , mean ))
197+ model .mean_module .constant = float ( "inf" )
198+ self .assertTrue (torch .equal (gp ._model .mean_module .raw_constant , mean ))
199199 # test basic functionality
200200 test_X1 = torch .rand (1 , 1 , ** tkwargs , requires_grad = True )
201201 Y1 = gp (test_X1 )
@@ -234,14 +234,14 @@ def test_gp_draw_multi_output(self):
234234 for dtype in (torch .float , torch .double ):
235235 tkwargs = {"device" : self .device , "dtype" : dtype }
236236 model , _ , _ = _get_model (** tkwargs , multi_output = True )
237- mean = model .mean_module .constant .detach ().clone ()
237+ mean = model .mean_module .raw_constant .detach ().clone ()
238238 gp = GPDraw (model )
239239 # test initialization
240240 self .assertIsNone (gp .Xs )
241241 self .assertIsNone (gp .Ys )
242242 # make sure model is actually deepcopied
243- model .mean_module .constant = None
244- self .assertTrue (torch .equal (gp ._model .mean_module .constant , mean ))
243+ model .mean_module .constant = float ( "inf" )
244+ self .assertTrue (torch .equal (gp ._model .mean_module .raw_constant , mean ))
245245 # test basic functionality
246246 test_X1 = torch .rand (1 , 1 , ** tkwargs , requires_grad = True )
247247 Y1 = gp (test_X1 )
0 commit comments