Skip to content

Commit d83664f

Browse files
committed
quick changes
1 parent aa2ff9f commit d83664f

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

causalpy/pymc_models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,10 +1064,14 @@ def build_model(
10641064

10651065
# Time data for trend and seasonality
10661066
t_trend_data = pm.Data(
1067-
"t_trend_data", time_for_trend, dims="obs_ind", mutable=True
1067+
"t_trend_data",
1068+
time_for_trend,
1069+
dims="obs_ind",
10681070
)
10691071
t_season_data = pm.Data(
1070-
"t_season_data", time_for_seasonality, dims="obs_ind", mutable=True
1072+
"t_season_data",
1073+
time_for_seasonality,
1074+
dims="obs_ind",
10711075
)
10721076

10731077
# Get validated components (no more ugly imports in build_model!)
@@ -1114,8 +1118,8 @@ def build_model(
11141118
f"Shape mismatch: X_values_for_pymc has {X_values_for_pymc.shape[1]} columns, but "
11151119
f"{len(self._exog_var_names)} names in self._exog_var_names ({self._exog_var_names})."
11161120
)
1117-
X_data = pm.Data(
1118-
"X", X_values_for_pymc, dims=["obs_ind", "coeffs"], mutable=True
1121+
X_data = pm.MutableData(
1122+
"X", X_values_for_pymc, dims=["obs_ind", "coeffs"]
11191123
)
11201124
beta = pm.Normal("beta", mu=0, sigma=10, dims="coeffs")
11211125
mu_ = mu_ + pm.math.dot(X_data, beta)
@@ -1125,7 +1129,7 @@ def build_model(
11251129

11261130
# Likelihood
11271131
sigma = pm.HalfNormal("sigma", sigma=self.prior_sigma)
1128-
y_data = pm.Data("y", y.flatten(), dims="obs_ind", mutable=True)
1132+
y_data = pm.MutableData("y", y.flatten(), dims="obs_ind")
11291133
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_data, dims="obs_ind")
11301134

11311135
def fit(

0 commit comments

Comments
 (0)