File tree Expand file tree Collapse file tree 3 files changed +48
-2
lines changed Expand file tree Collapse file tree 3 files changed +48
-2
lines changed Original file line number Diff line number Diff line change 21
21
import numpy as np
22
22
import pandas as pd
23
23
import seaborn as sns
24
+ import xarray as xr
24
25
from matplotlib import pyplot as plt
25
26
from patsy import build_design_matrices , dmatrices
26
27
from sklearn .base import RegressorMixin
@@ -111,6 +112,21 @@ def __init__(
111
112
self .y , self .X = np .asarray (y ), np .asarray (X )
112
113
self .outcome_variable_name = y .design_info .column_names [0 ]
113
114
115
+ # turn into xarray.DataArray's
116
+ self .X = xr .DataArray (
117
+ self .X ,
118
+ dims = ["obs_ind" , "coeffs" ],
119
+ coords = {
120
+ "obs_ind" : np .arange (self .X .shape [0 ]),
121
+ "coeffs" : self .labels ,
122
+ },
123
+ )
124
+ self .y = xr .DataArray (
125
+ self .y [:, 0 ],
126
+ dims = ["obs_ind" ],
127
+ coords = {"obs_ind" : self .data .index },
128
+ )
129
+
114
130
# fit the model to the observed (pre-intervention) data
115
131
if isinstance (self .model , PyMCModel ):
116
132
COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .X .shape [0 ])}
Original file line number Diff line number Diff line change 23
23
from matplotlib import pyplot as plt
24
24
from patsy import build_design_matrices , dmatrices
25
25
from sklearn .base import RegressorMixin
26
-
26
+ import xarray as xr
27
27
from causalpy .custom_exceptions import (
28
28
DataException ,
29
29
FormulaException ,
@@ -121,6 +121,21 @@ def __init__(
121
121
self .y , self .X = np .asarray (y ), np .asarray (X )
122
122
self .outcome_variable_name = y .design_info .column_names [0 ]
123
123
124
+ # turn into xarray.DataArray's
125
+ self .X = xr .DataArray (
126
+ self .X ,
127
+ dims = ["obs_ind" , "coeffs" ],
128
+ coords = {
129
+ "obs_ind" : np .arange (self .X .shape [0 ]),
130
+ "coeffs" : self .labels ,
131
+ },
132
+ )
133
+ self .y = xr .DataArray (
134
+ self .y [:, 0 ],
135
+ dims = ["obs_ind" ],
136
+ coords = {"obs_ind" : self .data .index },
137
+ )
138
+
124
139
# fit model
125
140
if isinstance (self .model , PyMCModel ):
126
141
# fit the model to the observed (pre-intervention) data
Original file line number Diff line number Diff line change 23
23
import pandas as pd
24
24
import seaborn as sns
25
25
from patsy import build_design_matrices , dmatrices
26
-
26
+ import xarray as xr
27
27
from causalpy .plot_utils import plot_xY
28
28
29
29
from .base import BaseExperiment
@@ -84,6 +84,21 @@ def __init__(
84
84
self .y , self .X = np .asarray (y ), np .asarray (X )
85
85
self .outcome_variable_name = y .design_info .column_names [0 ]
86
86
87
+ # turn into xarray.DataArray's
88
+ self .X = xr .DataArray (
89
+ self .X ,
90
+ dims = ["obs_ind" , "coeffs" ],
91
+ coords = {
92
+ "obs_ind" : np .arange (self .X .shape [0 ]),
93
+ "coeffs" : self .labels ,
94
+ },
95
+ )
96
+ self .y = xr .DataArray (
97
+ self .y [:, 0 ],
98
+ dims = ["obs_ind" ],
99
+ coords = {"obs_ind" : self .data .index },
100
+ )
101
+
87
102
COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .X .shape [0 ])}
88
103
self .model .fit (X = self .X , y = self .y , coords = COORDS )
89
104
You can’t perform that action at this time.
0 commit comments