Skip to content

Commit f94b197

Browse files
committed
Use dedicated sampling parameters for AR model test
Introduces a separate 'ar_sample_kwargs' dictionary for Bayesian AR(1) model sampling in the transfer function test. Updates assertions to reference the new parameters, clarifying the need for increased sampling due to model complexity.
1 parent 87a5002 commit f94b197

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,14 +1203,22 @@ def test_transfer_function_ar_bayesian(mock_pymc_sample):
12031203
)
12041204

12051205
# Fit Bayesian AR(1) model
1206+
# Note: AR models need more sampling than independent errors models due to complexity
1207+
ar_sample_kwargs = {
1208+
"tune": 50,
1209+
"draws": 50,
1210+
"chains": 2,
1211+
"cores": 2,
1212+
"random_seed": 42,
1213+
}
12061214
model = cp.pymc_models.TransferFunctionARRegression(
12071215
saturation_type=None,
12081216
adstock_config={
12091217
"half_life_prior": {"dist": "Gamma", "alpha": 4, "beta": 2},
12101218
"l_max": 8,
12111219
"normalize": True,
12121220
},
1213-
sample_kwargs=sample_kwargs,
1221+
sample_kwargs=ar_sample_kwargs,
12141222
)
12151223

12161224
result = cp.GradedInterventionTimeSeries(
@@ -1224,8 +1232,10 @@ def test_transfer_function_ar_bayesian(mock_pymc_sample):
12241232
# Test basic properties
12251233
assert isinstance(result, cp.GradedInterventionTimeSeries)
12261234
assert hasattr(result.model, "idata")
1227-
assert len(result.model.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
1228-
assert len(result.model.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
1235+
assert (
1236+
len(result.model.idata.posterior.coords["chain"]) == ar_sample_kwargs["chains"]
1237+
)
1238+
assert len(result.model.idata.posterior.coords["draw"]) == ar_sample_kwargs["draws"]
12291239

12301240
# Test that transform parameters are in posterior
12311241
assert "half_life" in result.model.idata.posterior

0 commit comments

Comments
 (0)