|
1 | 1 | import unittest |
2 | 2 | import numpy as np |
3 | | -from sklearn.linear_model import LogisticRegression |
| 3 | +from sklearn.linear_model import LogisticRegression, LinearRegression |
4 | 4 | from dte_adj import SimpleLocalDistributionEstimator, AdjustedLocalDistributionEstimator |
5 | 5 |
|
| 6 | +np.random.seed(123) |
| 7 | + |
| 8 | + |
| 9 | +def generate_data(n=1000, S=4): |
| 10 | + # Generate W ~ U(0,1) |
| 11 | + W = np.random.uniform(0, 1, n) |
| 12 | + |
| 13 | + # Assign strata based on W |
| 14 | + strata = np.digitize(W, np.linspace(0, 1, S + 1)[1:]) |
| 15 | + |
| 16 | + # Generate X ~ N(0, I_20) |
| 17 | + X = np.random.randn(n, 20) |
| 18 | + |
| 19 | + # Treatment assignment Z ~ Bernoulli(0.5) within each stratum |
| 20 | + Z = np.zeros(n) |
| 21 | + for s in range(S): |
| 22 | + indices = np.where(strata == s)[0] |
| 23 | + Z[indices] = np.random.binomial(1, 0.5, size=len(indices)) |
| 24 | + |
| 25 | + # Define functions b(X, W) and c(X, W) |
| 26 | + def b(X, W): |
| 27 | + return ( |
| 28 | + np.sin(np.pi * X[:, 0] * X[:, 1]) |
| 29 | + + 2 * (X[:, 2] - 0.5) ** 2 |
| 30 | + + X[:, 3] |
| 31 | + + 0.5 * X[:, 4] |
| 32 | + + 0.1 * W |
| 33 | + ) |
| 34 | + |
| 35 | + def c(X, W): |
| 36 | + return 0.1 * (X[:, 0] + np.log(1 + np.exp(X[:, 1])) + W) |
| 37 | + |
| 38 | + # Define parameters |
| 39 | + a1, a0 = 4, 1 |
| 40 | + b1, b0 = 1, -1 |
| 41 | + c1, c0 = 3, 3 |
| 42 | + |
| 43 | + # Generate errors |
| 44 | + epsilon = np.random.randn(n) |
| 45 | + |
| 46 | + # Compute Y(d) |
| 47 | + Y0 = a0 + b(X, W) + epsilon |
| 48 | + Y1 = a1 + b(X, W) + epsilon |
| 49 | + |
| 50 | + # Compute D(0) and D(1) |
| 51 | + D0 = (b0 + c(X, W) > c0 * epsilon).astype(int) |
| 52 | + D1 = np.where(D0 == 0, (b1 + c(X, W) > c1 * epsilon).astype(int), 1) |
| 53 | + |
| 54 | + # Compute observed D and Y |
| 55 | + D = D1 * Z + D0 * (1 - Z) |
| 56 | + Y = Y1 * D + Y0 * (1 - D) |
| 57 | + |
| 58 | + # discrete |
| 59 | + Y = np.random.poisson(np.abs(Y)) |
| 60 | + |
| 61 | + return { |
| 62 | + "W": W, |
| 63 | + "X": X, |
| 64 | + "Z": Z, |
| 65 | + "D": D, |
| 66 | + "Y": Y, |
| 67 | + "strata": strata, |
| 68 | + } |
| 69 | + |
6 | 70 |
|
7 | 71 | class TestLocalEstimators(unittest.TestCase): |
8 | 72 | def setUp(self): |
@@ -232,3 +296,43 @@ def test_adjusted_local_estimator_predict_lpte(self): |
232 | 296 | self.assertTrue(np.all(lower_bound <= upper_bound)) |
233 | 297 | self.assertTrue(np.all(lower_bound <= beta)) |
234 | 298 | self.assertTrue(np.all(beta <= upper_bound)) |
| 299 | + |
| 300 | + |
| 301 | +class TestE2E(unittest.TestCase): |
| 302 | + def test_e2e(self): |
| 303 | + # Arrange |
| 304 | + data = generate_data(n=3000) |
| 305 | + X, D, Y, Z, S = data["X"], data["W"], data["Y"], data["Z"], data["strata"] |
| 306 | + locations = np.array([np.percentile(Y, p) for p in range(10, 91, 10)]) |
| 307 | + simple_estimator = SimpleLocalDistributionEstimator() |
| 308 | + adjusted_estimator = AdjustedLocalDistributionEstimator(LinearRegression()) |
| 309 | + |
| 310 | + # Act |
| 311 | + simple_estimator.fit(X, Z, D, Y, S) |
| 312 | + adjusted_estimator.fit(X, Z, D, Y, S) |
| 313 | + |
| 314 | + simple_dte, simple_lower_bound, simple_upper_bound = ( |
| 315 | + simple_estimator.predict_dte(1, 0, locations) |
| 316 | + ) |
| 317 | + adjusted_dte, adjusted_lower_bound, adjusted_upper_bound = ( |
| 318 | + adjusted_estimator.predict_dte(1, 0, locations) |
| 319 | + ) |
| 320 | + |
| 321 | + # Assert |
| 322 | + np.testing.assert_(np.all(simple_dte < 0), "Not all values are negative") |
| 323 | + np.testing.assert_(np.all(adjusted_dte < 0), "Not all values are negative") |
| 324 | + np.testing.assert_( |
| 325 | + np.all(simple_lower_bound < simple_upper_bound), |
| 326 | + "Upper bound is less than lower bound", |
| 327 | + ) |
| 328 | + np.testing.assert_( |
| 329 | + np.all(adjusted_lower_bound < adjusted_upper_bound), |
| 330 | + "Upper bound is less than lower bound", |
| 331 | + ) |
| 332 | + np.testing.assert_( |
| 333 | + np.all( |
| 334 | + adjusted_upper_bound - adjusted_lower_bound |
| 335 | + < simple_upper_bound - simple_lower_bound |
| 336 | + ), |
| 337 | + "Adjusted estimator does not have narrower intervals", |
| 338 | + ) |
0 commit comments