File tree Expand file tree Collapse file tree 3 files changed +48
-2
lines changed Expand file tree Collapse file tree 3 files changed +48
-2
lines changed Original file line number Diff line number Diff line change 2121import numpy as np
2222import pandas as pd
2323import seaborn as sns
24+ import xarray as xr
2425from matplotlib import pyplot as plt
2526from patsy import build_design_matrices , dmatrices
2627from sklearn .base import RegressorMixin
@@ -111,6 +112,21 @@ def __init__(
111112 self .y , self .X = np .asarray (y ), np .asarray (X )
112113 self .outcome_variable_name = y .design_info .column_names [0 ]
113114
115+ # turn into xarray.DataArray's
116+ self .X = xr .DataArray (
117+ self .X ,
118+ dims = ["obs_ind" , "coeffs" ],
119+ coords = {
120+ "obs_ind" : np .arange (self .X .shape [0 ]),
121+ "coeffs" : self .labels ,
122+ },
123+ )
124+ self .y = xr .DataArray (
125+ self .y [:, 0 ],
126+ dims = ["obs_ind" ],
127+ coords = {"obs_ind" : self .data .index },
128+ )
129+
114130 # fit the model to the observed (pre-intervention) data
115131 if isinstance (self .model , PyMCModel ):
116132 COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .X .shape [0 ])}
Original file line number Diff line number Diff line change 2323from matplotlib import pyplot as plt
2424from patsy import build_design_matrices , dmatrices
2525from sklearn .base import RegressorMixin
26-
26+ import xarray as xr
2727from causalpy .custom_exceptions import (
2828 DataException ,
2929 FormulaException ,
@@ -121,6 +121,21 @@ def __init__(
121121 self .y , self .X = np .asarray (y ), np .asarray (X )
122122 self .outcome_variable_name = y .design_info .column_names [0 ]
123123
124+ # turn into xarray.DataArray's
125+ self .X = xr .DataArray (
126+ self .X ,
127+ dims = ["obs_ind" , "coeffs" ],
128+ coords = {
129+ "obs_ind" : np .arange (self .X .shape [0 ]),
130+ "coeffs" : self .labels ,
131+ },
132+ )
133+ self .y = xr .DataArray (
134+ self .y [:, 0 ],
135+ dims = ["obs_ind" ],
136+ coords = {"obs_ind" : self .data .index },
137+ )
138+
124139 # fit model
125140 if isinstance (self .model , PyMCModel ):
126141 # fit the model to the observed (pre-intervention) data
Original file line number Diff line number Diff line change 2323import pandas as pd
2424import seaborn as sns
2525from patsy import build_design_matrices , dmatrices
26-
26+ import xarray as xr
2727from causalpy .plot_utils import plot_xY
2828
2929from .base import BaseExperiment
@@ -84,6 +84,21 @@ def __init__(
8484 self .y , self .X = np .asarray (y ), np .asarray (X )
8585 self .outcome_variable_name = y .design_info .column_names [0 ]
8686
87+ # turn into xarray.DataArray's
88+ self .X = xr .DataArray (
89+ self .X ,
90+ dims = ["obs_ind" , "coeffs" ],
91+ coords = {
92+ "obs_ind" : np .arange (self .X .shape [0 ]),
93+ "coeffs" : self .labels ,
94+ },
95+ )
96+ self .y = xr .DataArray (
97+ self .y [:, 0 ],
98+ dims = ["obs_ind" ],
99+ coords = {"obs_ind" : self .data .index },
100+ )
101+
87102 COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .X .shape [0 ])}
88103 self .model .fit (X = self .X , y = self .y , coords = COORDS )
89104
You can’t perform that action at this time.
0 commit comments