Skip to content

Commit 2d31504

Browse files
committed
get tests passing
1 parent ec98c8a commit 2d31504

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

causalpy/tests/test_transfer_function_its.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,17 +198,27 @@ def test_grid_search_basic(self):
198198
t = np.arange(n)
199199
dates = pd.date_range("2020-01-01", periods=n, freq="W")
200200

201-
# Generate treatment with known transforms
202-
treatment_raw = np.random.uniform(0, 100, n)
201+
# Generate treatment with known transforms - use more varied signal
202+
treatment_raw = (
203+
50 + 30 * np.sin(2 * np.pi * t / 20) + np.random.uniform(-10, 10, n)
204+
)
205+
treatment_raw = np.maximum(treatment_raw, 0) # Keep non-negative
206+
203207
sat = HillSaturation(slope=2.0, kappa=50)
204208
treatment_sat = sat.apply(treatment_raw)
205209
adstock = GeometricAdstock(half_life=3.0, normalize=True)
206210
treatment_transformed = adstock.apply(treatment_sat)
207211

208-
# Generate outcome
212+
# Generate outcome with stronger signal and time trend
209213
beta_0 = 100.0
210-
theta = 5.0
211-
y = beta_0 + theta * treatment_transformed + np.random.normal(0, 2, n)
214+
beta_t = 0.5
215+
theta = 50.0 # Stronger treatment effect
216+
y = (
217+
beta_0
218+
+ beta_t * t
219+
+ theta * treatment_transformed
220+
+ np.random.normal(0, 5, n)
221+
)
212222

213223
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
214224
df = df.set_index("date")
@@ -219,7 +229,7 @@ def test_grid_search_basic(self):
219229
data=df,
220230
y_column="y",
221231
treatment_name="treatment",
222-
base_formula="1",
232+
base_formula="1 + t", # Include time trend
223233
estimation_method="grid",
224234
saturation_type="hill",
225235
saturation_grid={"slope": [1.5, 2.0, 2.5], "kappa": [40, 50, 60]},
@@ -249,17 +259,27 @@ def test_optimize_basic(self):
249259
t = np.arange(n)
250260
dates = pd.date_range("2020-01-01", periods=n, freq="W")
251261

252-
# Generate treatment with known transforms
253-
treatment_raw = np.random.uniform(0, 100, n)
262+
# Generate treatment with known transforms - use more varied signal
263+
treatment_raw = (
264+
50 + 30 * np.sin(2 * np.pi * t / 20) + np.random.uniform(-10, 10, n)
265+
)
266+
treatment_raw = np.maximum(treatment_raw, 0) # Keep non-negative
267+
254268
sat = HillSaturation(slope=2.0, kappa=50)
255269
treatment_sat = sat.apply(treatment_raw)
256270
adstock = GeometricAdstock(half_life=3.0, normalize=True)
257271
treatment_transformed = adstock.apply(treatment_sat)
258272

259-
# Generate outcome
273+
# Generate outcome with stronger signal and time trend
260274
beta_0 = 100.0
261-
theta = 5.0
262-
y = beta_0 + theta * treatment_transformed + np.random.normal(0, 2, n)
275+
beta_t = 0.5
276+
theta = 50.0 # Stronger treatment effect
277+
y = (
278+
beta_0
279+
+ beta_t * t
280+
+ theta * treatment_transformed
281+
+ np.random.normal(0, 5, n)
282+
)
263283

264284
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
265285
df = df.set_index("date")
@@ -269,7 +289,7 @@ def test_optimize_basic(self):
269289
data=df,
270290
y_column="y",
271291
treatment_name="treatment",
272-
base_formula="1",
292+
base_formula="1 + t", # Include time trend
273293
estimation_method="optimize",
274294
saturation_type="hill",
275295
saturation_bounds={"slope": (1.0, 4.0), "kappa": (20, 100)},

0 commit comments

Comments
 (0)