Skip to content

Commit b920207

Browse files
committed
store data in xarray objects in more experiments
1 parent 2726484 commit b920207

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

causalpy/experiments/prepostnegd.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
import pandas as pd
2323
import seaborn as sns
24+
import xarray as xr
2425
from matplotlib import pyplot as plt
2526
from patsy import build_design_matrices, dmatrices
2627
from sklearn.base import RegressorMixin
@@ -111,6 +112,21 @@ def __init__(
111112
self.y, self.X = np.asarray(y), np.asarray(X)
112113
self.outcome_variable_name = y.design_info.column_names[0]
113114

115+
# turn into xarray.DataArray's
116+
self.X = xr.DataArray(
117+
self.X,
118+
dims=["obs_ind", "coeffs"],
119+
coords={
120+
"obs_ind": np.arange(self.X.shape[0]),
121+
"coeffs": self.labels,
122+
},
123+
)
124+
self.y = xr.DataArray(
125+
self.y[:, 0],
126+
dims=["obs_ind"],
127+
coords={"obs_ind": self.data.index},
128+
)
129+
114130
# fit the model to the observed (pre-intervention) data
115131
if isinstance(self.model, PyMCModel):
116132
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}

causalpy/experiments/regression_discontinuity.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from matplotlib import pyplot as plt
2424
from patsy import build_design_matrices, dmatrices
2525
from sklearn.base import RegressorMixin
26-
26+
import xarray as xr
2727
from causalpy.custom_exceptions import (
2828
DataException,
2929
FormulaException,
@@ -121,6 +121,21 @@ def __init__(
121121
self.y, self.X = np.asarray(y), np.asarray(X)
122122
self.outcome_variable_name = y.design_info.column_names[0]
123123

124+
# turn into xarray.DataArray's
125+
self.X = xr.DataArray(
126+
self.X,
127+
dims=["obs_ind", "coeffs"],
128+
coords={
129+
"obs_ind": np.arange(self.X.shape[0]),
130+
"coeffs": self.labels,
131+
},
132+
)
133+
self.y = xr.DataArray(
134+
self.y[:, 0],
135+
dims=["obs_ind"],
136+
coords={"obs_ind": self.data.index},
137+
)
138+
124139
# fit model
125140
if isinstance(self.model, PyMCModel):
126141
# fit the model to the observed (pre-intervention) data

causalpy/experiments/regression_kink.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pandas as pd
2424
import seaborn as sns
2525
from patsy import build_design_matrices, dmatrices
26-
26+
import xarray as xr
2727
from causalpy.plot_utils import plot_xY
2828

2929
from .base import BaseExperiment
@@ -84,6 +84,21 @@ def __init__(
8484
self.y, self.X = np.asarray(y), np.asarray(X)
8585
self.outcome_variable_name = y.design_info.column_names[0]
8686

87+
# turn into xarray.DataArray's
88+
self.X = xr.DataArray(
89+
self.X,
90+
dims=["obs_ind", "coeffs"],
91+
coords={
92+
"obs_ind": np.arange(self.X.shape[0]),
93+
"coeffs": self.labels,
94+
},
95+
)
96+
self.y = xr.DataArray(
97+
self.y[:, 0],
98+
dims=["obs_ind"],
99+
coords={"obs_ind": self.data.index},
100+
)
101+
87102
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
88103
self.model.fit(X=self.X, y=self.y, coords=COORDS)
89104

0 commit comments

Comments
 (0)