Skip to content

Commit a28c5da

Browse files
committed
attempt to make LinearRegression doctest pass
1 parent b920207 commit a28c5da

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

causalpy/pymc_models.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,16 +216,28 @@ class LinearRegression(PyMCModel):
216216
--------
217217
>>> import causalpy as cp
218218
>>> import numpy as np
219+
>>> import xarray as xr
219220
>>> from causalpy.pymc_models import LinearRegression
220221
>>> rd = cp.load_data("rd")
221-
>>> X = rd[["x", "treated"]]
222-
>>> y = np.asarray(rd["y"]).reshape((rd["y"].shape[0],1))
222+
>>> X = xr.DataArray(
223+
>>> rd[["x", "treated"]],
224+
>>> dims=["obs_ind", "coeffs"],
225+
>>> coords={
226+
>>> "obs_ind": rd.index,
227+
>>> "coeffs":coeffs,
228+
>>> },
229+
>>> )
230+
>>> y = xr.DataArray(
231+
>>> np.asarray(rd["y"]),
232+
>>> dims=["obs_ind"],
233+
>>> coords={"obs_ind": rd.index},
234+
>>> )
223235
>>> lr = LinearRegression(sample_kwargs={"progressbar": False})
224-
>>> lr.fit(X, y, coords={
225-
... 'coeffs': ['x', 'treated'],
226-
... 'obs_ind': np.arange(rd.shape[0])
227-
... },
228-
... )
236+
>>> coords={
237+
>>> "coeffs": coeffs,
238+
>>> "obs_ind": np.arange(rd.shape[0]),
239+
>>> }
240+
>>> lr.fit(X, y, coords=coords)
229241
Inference data...
230242
""" # noqa: W605
231243

@@ -264,7 +276,7 @@ class WeightedSumFitter(PyMCModel):
264276
>>> X = sc[['a', 'b', 'c', 'd', 'e', 'f', 'g']]
265277
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
266278
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
267-
>>> wsf.fit(X,y)
279+
>>> wsf.fit(X, y)
268280
Inference data...
269281
""" # noqa: W605
270282

0 commit comments

Comments
 (0)