@@ -789,19 +789,26 @@ class InterventionTimeEstimator(PyMCModel):
789789 >>> labels = X.design_info.column_names
790790 >>> _y, _X = np.asarray(y), np.asarray(X)
791791 >>> _X = xr.DataArray(
792- ... _X,
793- ... dims=["obs_ind", "coeffs"],
794- ... coords={
795- ... "obs_ind": data.index,
796- ... "coeffs": labels,
797- ... },
792+ ... _X,
793+ ... dims=["obs_ind", "coeffs"],
794+ ... coords={
795+ ... "obs_ind": data.index,
796+ ... "coeffs": labels,
797+ ... },
798798 ... )
799799 >>> _y = xr.DataArray(
800- ... _y[:, 0],
801- ... dims=["obs_ind"],
802- ... coords={"obs_ind": data.index},
803- ... )
804- >>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
800+ ... _y,
801+ ... dims=["obs_ind", "treated_units"],
802+ ... coords={
803+ ... "obs_ind": data.index,
804+ ... "treated_units": ["unit_0"]
805+ ... },
806+ ... )
807+ >>> COORDS = {
808+ ... "coeffs": labels,
809+ ... "obs_ind": np.arange(X.shape[0]),
810+ ... "treated_units": ["unit_0"],
811+ ... }
805812 >>> model = ITE(treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
806813 >>> model.set_time_range(None, data)
807814 >>> model.fit(X=_X, y=_y, coords=COORDS)
@@ -909,8 +916,8 @@ def build_model(self, X, y, coords):
909916 )
910917 delta_t = pm .Deterministic (
911918 name = "delta_t" ,
912- var = (t - treatment_time )[:, None ] ,
913- dims = ["obs_ind" , "treated_units" ],
919+ var = (t - treatment_time ),
920+ dims = ["obs_ind" ],
914921 )
915922 beta = pm .Normal (
916923 name = "beta" ,
@@ -927,33 +934,28 @@ def build_model(self, X, y, coords):
927934 "level" ,
928935 mu = self .treatment_effect_param ["level" ][0 ],
929936 sigma = self .treatment_effect_param ["level" ][1 ],
930- dims = ["obs_ind" , "treated_units" ],
931937 )
932938 mu_in_components .append (level )
933939 if "trend" in self .treatment_effect_param :
934940 trend = pm .Normal (
935941 "trend" ,
936942 mu = self .treatment_effect_param ["trend" ][0 ],
937943 sigma = self .treatment_effect_param ["trend" ][1 ],
938- dims = ["obs_ind" , "treated_units" ],
939944 )
940945 mu_in_components .append (trend * delta_t )
941946 if "impulse" in self .treatment_effect_param :
942947 impulse_amplitude = pm .Normal (
943948 "impulse_amplitude" ,
944949 mu = self .treatment_effect_param ["impulse" ][0 ],
945950 sigma = self .treatment_effect_param ["impulse" ][1 ],
946- dims = ["obs_ind" , "treated_units" ],
947951 )
948952 decay_rate = pm .HalfNormal (
949953 "decay_rate" ,
950954 sigma = self .treatment_effect_param ["impulse" ][2 ],
951- dims = ["obs_ind" , "treated_units" ],
952955 )
953956 impulse = pm .Deterministic (
954957 "impulse" ,
955958 impulse_amplitude * pm .math .exp (- decay_rate * pm .math .abs (delta_t )),
956- dims = ["obs_ind" , "treated_units" ],
957959 )
958960 mu_in_components .append (impulse )
959961
@@ -968,18 +970,18 @@ def build_model(self, X, y, coords):
968970 pm .Deterministic (
969971 name = "mu_in" ,
970972 var = sum (mu_in_components ),
971- dims = ["obs_ind" , "treated_units" ],
972973 )
973974 if len (mu_in_components ) > 0
974975 else pm .Data (
975976 name = "mu_in" ,
976- vars = np .zeros ((X .sizes ["obs_ind" ], y .sizes ["treated_units" ])),
977- dims = ["obs_ind" , "treated_units" ],
977+ vars = 0 ,
978978 )
979979 )
980980 # Compute and store the sum of the base time series and the intervention's effect
981981 mu_ts = pm .Deterministic (
982- "mu_ts" , mu + weight * mu_in , dims = ["obs_ind" , "treated_units" ]
982+ "mu_ts" ,
983+ mu + (weight * mu_in )[:, None ],
984+ dims = ["obs_ind" , "treated_units" ],
983985 )
984986 sigma = pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
985987
@@ -1016,7 +1018,7 @@ def predict(self, X):
10161018 )
10171019
10181020 # TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
1019- if isinstance (X , xr .DataArray ):
1021+ if isinstance (X , xr .DataArray ) and "obs_ind" in X . coords :
10201022 pp ["posterior_predictive" ] = pp ["posterior_predictive" ].assign_coords (
10211023 obs_ind = X .obs_ind
10221024 )
0 commit comments