@@ -258,3 +258,28 @@ def test_geolift1():
258
258
assert isinstance (result , cp .pymc_experiments .SyntheticControl )
259
259
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
260
260
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
261
+
262
+
263
+ @pytest .mark .integration
264
+ def test_iv_reg ():
265
+ df = cp .load_data ("risk" )
266
+ instruments_formula = "risk ~ 1 + logmort0"
267
+ formula = "loggdp ~ 1 + risk"
268
+ instruments_data = df [["risk" , "logmort0" ]]
269
+ data = df [["loggdp" , "risk" ]]
270
+
271
+ result = cp .pymc_experiments .InstrumentalVariable (
272
+ instruments_data = instruments_data ,
273
+ data = data ,
274
+ instruments_formula = instruments_formula ,
275
+ formula = formula ,
276
+ model = cp .pymc_models .InstrumentalVariableRegression (
277
+ sample_kwargs = sample_kwargs
278
+ ),
279
+ )
280
+ assert isinstance (df , pd .DataFrame )
281
+ assert isinstance (data , pd .DataFrame )
282
+ assert isinstance (instruments_data , pd .DataFrame )
283
+ assert isinstance (result , cp .pymc_experiments .InstrumentalVariable )
284
+ assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
285
+ assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
0 commit comments