Skip to content

Commit 1f753e9

Browse files
committed
final doctest now passes 😍
1 parent b49ed7e commit 1f753e9

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

causalpy/pymc_models.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -219,24 +219,20 @@ class LinearRegression(PyMCModel):
219219
>>> import xarray as xr
220220
>>> from causalpy.pymc_models import LinearRegression
221221
>>> rd = cp.load_data("rd")
222+
>>> rd["treated"] = rd["treated"].astype(int)
223+
>>> coeffs = ["x", "treated"]
222224
>>> X = xr.DataArray(
223-
>>> rd[["x", "treated"]],
224-
>>> dims=["obs_ind", "coeffs"],
225-
>>> coords={
226-
>>> "obs_ind": rd.index,
227-
>>> "coeffs":coeffs,
228-
>>> },
229-
>>> )
225+
... rd[coeffs].values,
226+
... dims=["obs_ind", "coeffs"],
227+
... coords={"obs_ind": rd.index, "coeffs": coeffs},
228+
... )
230229
>>> y = xr.DataArray(
231-
>>> np.asarray(rd["y"]),
232-
>>> dims=["obs_ind"],
233-
>>> coords={"obs_ind": rd.index},
234-
>>> )
230+
... rd["y"].values,
231+
... dims=["obs_ind"],
232+
... coords={"obs_ind": rd.index},
233+
... )
235234
>>> lr = LinearRegression(sample_kwargs={"progressbar": False})
236-
>>> coords={
237-
>>> "coeffs": coeffs,
238-
>>> "obs_ind": np.arange(rd.shape[0]),
239-
>>> }
235+
>>> coords={"coeffs": coeffs, "obs_ind": np.arange(rd.shape[0])}
240236
>>> lr.fit(X, y, coords=coords)
241237
Inference data...
242238
""" # noqa: W605

0 commit comments

Comments
 (0)