@@ -216,16 +216,28 @@ class LinearRegression(PyMCModel):
216216 --------
217217 >>> import causalpy as cp
218218 >>> import numpy as np
219+ >>> import xarray as xr
219220 >>> from causalpy.pymc_models import LinearRegression
220221 >>> rd = cp.load_data("rd")
221- >>> X = rd[["x", "treated"]]
222- >>> y = np.asarray(rd["y"]).reshape((rd["y"].shape[0],1))
222+ >>> X = xr.DataArray(
223+ >>> rd[["x", "treated"]],
224+ >>> dims=["obs_ind", "coeffs"],
225+ >>> coords={
226+ >>> "obs_ind": rd.index,
227+ >>> "coeffs":coeffs,
228+ >>> },
229+ >>> )
230+ >>> y = xr.DataArray(
231+ >>> np.asarray(rd["y"]),
232+ >>> dims=["obs_ind"],
233+ >>> coords={"obs_ind": rd.index},
234+ >>> )
223235 >>> lr = LinearRegression(sample_kwargs={"progressbar": False})
224- >>> lr.fit(X, y, coords={
225- ... ' coeffs': ['x', 'treated'] ,
226- ... ' obs_ind' : np.arange(rd.shape[0])
227- ... },
228- ... )
236+ >>> coords={
237+ >>> " coeffs": coeffs ,
238+ >>> " obs_ind" : np.arange(rd.shape[0]),
239+ >>> }
240+ >>> lr.fit(X, y, coords=coords )
229241 Inference data...
230242 """ # noqa: W605
231243
@@ -264,7 +276,7 @@ class WeightedSumFitter(PyMCModel):
264276 >>> X = sc[['a', 'b', 'c', 'd', 'e', 'f', 'g']]
265277 >>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
266278 >>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
267- >>> wsf.fit(X,y)
279+ >>> wsf.fit(X, y)
268280 Inference data...
269281 """ # noqa: W605
270282
0 commit comments