Skip to content

Commit e5ee32c

Browse files
committed
resolve conflicts
1 parent cf2a6f7 commit e5ee32c

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

causalpy/pymc_models.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -789,19 +789,26 @@ class InterventionTimeEstimator(PyMCModel):
789789
>>> labels = X.design_info.column_names
790790
>>> _y, _X = np.asarray(y), np.asarray(X)
791791
>>> _X = xr.DataArray(
792-
... _X,
793-
... dims=["obs_ind", "coeffs"],
794-
... coords={
795-
... "obs_ind": data.index,
796-
... "coeffs": labels,
797-
... },
792+
... _X,
793+
... dims=["obs_ind", "coeffs"],
794+
... coords={
795+
... "obs_ind": data.index,
796+
... "coeffs": labels,
797+
... },
798798
... )
799799
>>> _y = xr.DataArray(
800-
... _y[:, 0],
801-
... dims=["obs_ind"],
802-
... coords={"obs_ind": data.index},
803-
... )
804-
>>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
800+
... _y,
801+
... dims=["obs_ind", "treated_units"],
802+
... coords={
803+
... "obs_ind": data.index,
804+
... "treated_units": ["unit_0"]
805+
... },
806+
... )
807+
>>> COORDS = {
808+
... "coeffs": labels,
809+
... "obs_ind": np.arange(X.shape[0]),
810+
... "treated_units": ["unit_0"],
811+
... }
805812
>>> model = ITE(treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
806813
>>> model.set_time_range(None, data)
807814
>>> model.fit(X=_X, y=_y, coords=COORDS)
@@ -909,8 +916,8 @@ def build_model(self, X, y, coords):
909916
)
910917
delta_t = pm.Deterministic(
911918
name="delta_t",
912-
var=(t - treatment_time)[:, None],
913-
dims=["obs_ind", "treated_units"],
919+
var=(t - treatment_time),
920+
dims=["obs_ind"],
914921
)
915922
beta = pm.Normal(
916923
name="beta",
@@ -927,33 +934,28 @@ def build_model(self, X, y, coords):
927934
"level",
928935
mu=self.treatment_effect_param["level"][0],
929936
sigma=self.treatment_effect_param["level"][1],
930-
dims=["obs_ind", "treated_units"],
931937
)
932938
mu_in_components.append(level)
933939
if "trend" in self.treatment_effect_param:
934940
trend = pm.Normal(
935941
"trend",
936942
mu=self.treatment_effect_param["trend"][0],
937943
sigma=self.treatment_effect_param["trend"][1],
938-
dims=["obs_ind", "treated_units"],
939944
)
940945
mu_in_components.append(trend * delta_t)
941946
if "impulse" in self.treatment_effect_param:
942947
impulse_amplitude = pm.Normal(
943948
"impulse_amplitude",
944949
mu=self.treatment_effect_param["impulse"][0],
945950
sigma=self.treatment_effect_param["impulse"][1],
946-
dims=["obs_ind", "treated_units"],
947951
)
948952
decay_rate = pm.HalfNormal(
949953
"decay_rate",
950954
sigma=self.treatment_effect_param["impulse"][2],
951-
dims=["obs_ind", "treated_units"],
952955
)
953956
impulse = pm.Deterministic(
954957
"impulse",
955958
impulse_amplitude * pm.math.exp(-decay_rate * pm.math.abs(delta_t)),
956-
dims=["obs_ind", "treated_units"],
957959
)
958960
mu_in_components.append(impulse)
959961

@@ -968,18 +970,18 @@ def build_model(self, X, y, coords):
968970
pm.Deterministic(
969971
name="mu_in",
970972
var=sum(mu_in_components),
971-
dims=["obs_ind", "treated_units"],
972973
)
973974
if len(mu_in_components) > 0
974975
else pm.Data(
975976
name="mu_in",
976-
vars=np.zeros((X.sizes["obs_ind"], y.sizes["treated_units"])),
977-
dims=["obs_ind", "treated_units"],
977+
vars=0,
978978
)
979979
)
980980
# Compute and store the sum of the base time series and the intervention's effect
981981
mu_ts = pm.Deterministic(
982-
"mu_ts", mu + weight * mu_in, dims=["obs_ind", "treated_units"]
982+
"mu_ts",
983+
mu + (weight * mu_in)[:, None],
984+
dims=["obs_ind", "treated_units"],
983985
)
984986
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
985987

@@ -1016,7 +1018,7 @@ def predict(self, X):
10161018
)
10171019

10181020
# TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
1019-
if isinstance(X, xr.DataArray):
1021+
if isinstance(X, xr.DataArray) and "obs_ind" in X.coords:
10201022
pp["posterior_predictive"] = pp["posterior_predictive"].assign_coords(
10211023
obs_ind=X.obs_ind
10221024
)

0 commit comments

Comments
 (0)