@@ -40,6 +40,24 @@ def toy_y(toy_X):
4040 return y
4141
4242
43+ @pytest .fixture (scope = "module" )
44+ def fitted_model_instance (toy_X , toy_y ):
45+ sampler_config = {
46+ "draws" : 500 ,
47+ "tune" : 300 ,
48+ "chains" : 2 ,
49+ "target_accept" : 0.95 ,
50+ }
51+ model_config = {
52+ "a" : {"loc" : 0 , "scale" : 10 },
53+ "b" : {"loc" : 0 , "scale" : 10 },
54+ "obs_error" : 2 ,
55+ }
56+ model = test_ModelBuilder (model_config = model_config , sampler_config = sampler_config )
57+ model .fit (toy_X )
58+ return model
59+
60+
4361class test_ModelBuilder (ModelBuilder ):
4462
4563 _model_type = "LinearModel"
@@ -103,16 +121,11 @@ def default_sampler_config(self) -> Dict:
103121 "target_accept" : 0.95 ,
104122 }
105123
106- @staticmethod
107- def initial_build_and_fit (toy_X , toy_y , check_idata = True ) -> ModelBuilder :
108- model_builder = test_ModelBuilder ()
109- model_builder .idata = model_builder .fit (
110- toy_X , toy_y , predictor_names = ["input" ], random_seed = 1234
111- )
112- if check_idata :
113- assert model_builder .idata is not None
114- assert "posterior" in model_builder .idata .groups ()
115- return model_builder
124+
125+ def test_initial_build_and_fit (fitted_model_instance , check_idata = True ) -> ModelBuilder :
126+ if check_idata :
127+ assert fitted_model_instance .idata is not None
128+ assert "posterior" in fitted_model_instance .idata .groups ()
116129
117130
118131def test_save_without_fit_raises_runtime_error ():
@@ -129,59 +142,62 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
129142 assert "posterior" in model_builder .idata .groups ()
130143
131144
132- def test_fit (toy_X , toy_y ):
133- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
134- x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
135- prediction_data = pd .DataFrame ({"input" : x_pred })
136- pred = model_builder .predict (prediction_data ["input" ])
137- post_pred = model_builder .sample_posterior_predictive (
145+ def test_fit (fitted_model_instance ):
146+ prediction_data = pd .DataFrame ({"input" : np .random .uniform (low = 0 , high = 1 , size = 100 )})
147+ pred = fitted_model_instance .predict (prediction_data ["input" ])
148+ post_pred = fitted_model_instance .sample_posterior_predictive (
138149 prediction_data ["input" ], extend_idata = True , combined = True
139150 )
140- post_pred [model_builder .output_var ].shape [0 ] == prediction_data .input .shape
151+ post_pred [fitted_model_instance .output_var ].shape [0 ] == prediction_data .input .shape
152+
153+
154+ def test_fit_no_y (toy_X ):
155+ model_builder = test_ModelBuilder ()
156+ model_builder .idata = model_builder .fit (X = toy_X )
157+ assert model_builder .model is not None
158+ assert model_builder .idata is not None
159+ assert "posterior" in model_builder .idata .groups ()
141160
142161
143162@pytest .mark .skipif (
144163 sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
145164)
146- def test_save_load (toy_X , toy_y ):
147- test_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
165+ def test_save_load (fitted_model_instance ):
148166 temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
149- test_builder .save (temp .name )
167+ fitted_model_instance .save (temp .name )
150168 test_builder2 = test_ModelBuilder .load (temp .name )
151- assert test_builder .idata .groups () == test_builder2 .idata .groups ()
169+ assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
152170
153171 x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
154172 prediction_data = pd .DataFrame ({"input" : x_pred })
155- pred1 = test_builder .predict (prediction_data ["input" ])
173+ pred1 = fitted_model_instance .predict (prediction_data ["input" ])
156174 pred2 = test_builder2 .predict (prediction_data ["input" ])
157175 assert pred1 .shape == pred2 .shape
158176 temp .close ()
159177
160178
161- def test_predict (toy_X , toy_y ):
162- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
179+ def test_predict (fitted_model_instance ):
163180 x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
164181 prediction_data = pd .DataFrame ({"input" : x_pred })
165- pred = model_builder .predict (prediction_data ["input" ])
182+ pred = fitted_model_instance .predict (prediction_data ["input" ])
166183 # Perform elementwise comparison using numpy
167184 assert type (pred ) == np .ndarray
168185 assert len (pred ) > 0
169186
170187
171188@pytest .mark .parametrize ("combined" , [True , False ])
172- def test_sample_posterior_predictive (toy_X , toy_y , combined ):
173- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
189+ def test_sample_posterior_predictive (fitted_model_instance , combined ):
174190 n_pred = 100
175191 x_pred = np .random .uniform (low = 0 , high = 1 , size = n_pred )
176192 prediction_data = pd .DataFrame ({"input" : x_pred })
177- pred = model_builder .sample_posterior_predictive (
193+ pred = fitted_model_instance .sample_posterior_predictive (
178194 prediction_data ["input" ], combined = combined , extend_idata = True
179195 )
180- chains = model_builder .idata .sample_stats .dims ["chain" ]
181- draws = model_builder .idata .sample_stats .dims ["draw" ]
196+ chains = fitted_model_instance .idata .sample_stats .dims ["chain" ]
197+ draws = fitted_model_instance .idata .sample_stats .dims ["draw" ]
182198 expected_shape = (n_pred , chains * draws ) if combined else (chains , draws , n_pred )
183- assert pred [model_builder .output_var ].shape == expected_shape
184- assert np .issubdtype (pred [model_builder .output_var ].dtype , np .floating )
199+ assert pred [fitted_model_instance .output_var ].shape == expected_shape
200+ assert np .issubdtype (pred [fitted_model_instance .output_var ].dtype , np .floating )
185201
186202
187203def test_id ():
0 commit comments