Skip to content

Commit c437c91

Browse files
committed
increase test coverage
1 parent 6da90fe commit c437c91

File tree

2 files changed

+344
-3
lines changed

2 files changed

+344
-3
lines changed

causalpy/tests/test_transfer_function_its.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,3 +1477,344 @@ def test_treatment_only_adstock(self):
14771477
assert treatment.saturation is None
14781478
assert treatment.adstock is not None
14791479
assert treatment.lag is None
1480+
1481+
1482+
class TestBuildTreatmentMatrix:
1483+
"""Test _build_treatment_matrix internal method."""
1484+
1485+
def test_build_treatment_matrix_saturation_adstock(self):
1486+
"""Test _build_treatment_matrix with saturation and adstock."""
1487+
np.random.seed(42)
1488+
n = 50
1489+
t = np.arange(n)
1490+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1491+
1492+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1493+
treatment_raw = np.maximum(treatment_raw, 0)
1494+
y = 100.0 + 0.5 * t + np.random.normal(0, 5, n)
1495+
1496+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1497+
df = df.set_index("date")
1498+
1499+
model = TransferFunctionOLS(
1500+
saturation_type="hill",
1501+
saturation_grid={"slope": [2.0], "kappa": [50]},
1502+
adstock_grid={"half_life": [3], "l_max": [12], "normalize": [True]},
1503+
estimation_method="grid",
1504+
error_model="hac",
1505+
)
1506+
1507+
result = GradedInterventionTimeSeries(
1508+
data=df,
1509+
y_column="y",
1510+
treatment_names=["treatment"],
1511+
base_formula="1 + t",
1512+
model=model,
1513+
)
1514+
1515+
# Test the internal method
1516+
treatments = result.treatments
1517+
Z, labels = result._build_treatment_matrix(df, treatments)
1518+
1519+
assert Z.shape == (n, 1)
1520+
assert labels == ["treatment"]
1521+
assert not np.array_equal(Z.flatten(), treatment_raw) # Should be transformed
1522+
1523+
def test_build_treatment_matrix_single_transform(self):
1524+
"""Test _build_treatment_matrix with only adstock."""
1525+
np.random.seed(42)
1526+
n = 50
1527+
t = np.arange(n)
1528+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1529+
1530+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1531+
treatment_raw = np.maximum(treatment_raw, 0)
1532+
y = 100.0 + 0.5 * t + np.random.normal(0, 5, n)
1533+
1534+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1535+
df = df.set_index("date")
1536+
1537+
model = TransferFunctionOLS(
1538+
saturation_type=None,
1539+
adstock_grid={"half_life": [3], "l_max": [12], "normalize": [True]},
1540+
estimation_method="grid",
1541+
error_model="hac",
1542+
)
1543+
1544+
result = GradedInterventionTimeSeries(
1545+
data=df,
1546+
y_column="y",
1547+
treatment_names=["treatment"],
1548+
base_formula="1 + t",
1549+
model=model,
1550+
)
1551+
1552+
# Test the internal method
1553+
treatments = result.treatments
1554+
Z, labels = result._build_treatment_matrix(df, treatments)
1555+
1556+
assert Z.shape == (n, 1)
1557+
assert labels == ["treatment"]
1558+
1559+
1560+
class TestPlotIRFEdgeCases:
1561+
"""Test plot_irf edge cases and error handling."""
1562+
1563+
def test_plot_irf_invalid_channel(self):
1564+
"""Test plot_irf with invalid channel name."""
1565+
np.random.seed(42)
1566+
n = 50
1567+
t = np.arange(n)
1568+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1569+
1570+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1571+
treatment_raw = np.maximum(treatment_raw, 0)
1572+
y = 100.0 + 0.5 * t + np.random.normal(0, 5, n)
1573+
1574+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1575+
df = df.set_index("date")
1576+
1577+
model = TransferFunctionOLS(
1578+
saturation_type=None,
1579+
adstock_grid={"half_life": [3]},
1580+
estimation_method="grid",
1581+
error_model="hac",
1582+
)
1583+
1584+
result = GradedInterventionTimeSeries(
1585+
data=df,
1586+
y_column="y",
1587+
treatment_names=["treatment"],
1588+
base_formula="1 + t",
1589+
model=model,
1590+
)
1591+
1592+
with pytest.raises(ValueError, match="Channel.*not found"):
1593+
result.plot_irf("nonexistent_channel")
1594+
1595+
1596+
class TestSummaryMethod:
1597+
"""Test summary() method edge cases."""
1598+
1599+
def test_summary_with_arimax(self, capsys):
1600+
"""Test summary() with ARIMAX error model."""
1601+
np.random.seed(42)
1602+
n = 100
1603+
t = np.arange(n)
1604+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1605+
1606+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1607+
treatment_raw = np.maximum(treatment_raw, 0)
1608+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1609+
1610+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1611+
df = df.set_index("date")
1612+
1613+
model = TransferFunctionOLS(
1614+
saturation_type=None,
1615+
adstock_grid={"half_life": [3]},
1616+
estimation_method="grid",
1617+
error_model="arimax",
1618+
arima_order=(1, 0, 0),
1619+
)
1620+
1621+
result = GradedInterventionTimeSeries(
1622+
data=df,
1623+
y_column="y",
1624+
treatment_names=["treatment"],
1625+
base_formula="1 + t",
1626+
model=model,
1627+
)
1628+
1629+
result.summary(round_to=3)
1630+
1631+
captured = capsys.readouterr()
1632+
assert "ARIMAX" in captured.out
1633+
assert "ARIMA order" in captured.out
1634+
assert "(1, 0, 0)" in captured.out
1635+
1636+
def test_summary_custom_round_to(self, capsys):
1637+
"""Test summary() with custom round_to parameter."""
1638+
np.random.seed(42)
1639+
n = 50
1640+
t = np.arange(n)
1641+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1642+
1643+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1644+
treatment_raw = np.maximum(treatment_raw, 0)
1645+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1646+
1647+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1648+
df = df.set_index("date")
1649+
1650+
model = TransferFunctionOLS(
1651+
saturation_type=None,
1652+
adstock_grid={"half_life": [3]},
1653+
estimation_method="grid",
1654+
error_model="hac",
1655+
)
1656+
1657+
result = GradedInterventionTimeSeries(
1658+
data=df,
1659+
y_column="y",
1660+
treatment_names=["treatment"],
1661+
base_formula="1 + t",
1662+
model=model,
1663+
)
1664+
1665+
result.summary(round_to=4)
1666+
1667+
captured = capsys.readouterr()
1668+
assert "Graded Intervention Time Series Results" in captured.out
1669+
1670+
1671+
class TestModelTypeValidation:
1672+
"""Test validation of model types."""
1673+
1674+
def test_invalid_model_type_raises_error(self):
1675+
"""Test that invalid model type raises ValueError."""
1676+
np.random.seed(42)
1677+
n = 50
1678+
t = np.arange(n)
1679+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1680+
1681+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1682+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1683+
1684+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1685+
df = df.set_index("date")
1686+
1687+
# Use an invalid model (just a string)
1688+
with pytest.raises(ValueError, match="Model type not recognized"):
1689+
GradedInterventionTimeSeries(
1690+
data=df,
1691+
y_column="y",
1692+
treatment_names=["treatment"],
1693+
base_formula="1 + t",
1694+
model="invalid_model",
1695+
)
1696+
1697+
1698+
class TestRangeIndexSupport:
1699+
"""Test that integer RangeIndex is supported."""
1700+
1701+
def test_range_index_works(self):
1702+
"""Test that RangeIndex is accepted as valid index."""
1703+
np.random.seed(42)
1704+
n = 50
1705+
t = np.arange(n)
1706+
1707+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1708+
treatment_raw = np.maximum(treatment_raw, 0)
1709+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1710+
1711+
df = pd.DataFrame({"t": t, "y": y, "treatment": treatment_raw})
1712+
# RangeIndex is the default for DataFrame without explicit index
1713+
assert isinstance(df.index, pd.RangeIndex)
1714+
1715+
model = TransferFunctionOLS(
1716+
saturation_type=None,
1717+
adstock_grid={"half_life": [3]},
1718+
estimation_method="grid",
1719+
error_model="hac",
1720+
)
1721+
1722+
# Should not raise an error
1723+
result = GradedInterventionTimeSeries(
1724+
data=df,
1725+
y_column="y",
1726+
treatment_names=["treatment"],
1727+
base_formula="1 + t",
1728+
model=model,
1729+
)
1730+
1731+
assert result.ols_result is not None
1732+
1733+
def test_integer_index_works(self):
1734+
"""Test that explicit integer Index is accepted."""
1735+
np.random.seed(42)
1736+
n = 50
1737+
t = np.arange(n)
1738+
1739+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1740+
treatment_raw = np.maximum(treatment_raw, 0)
1741+
y = 100.0 + 0.5 * t + treatment_raw + np.random.normal(0, 5, n)
1742+
1743+
df = pd.DataFrame({"t": t, "y": y, "treatment": treatment_raw})
1744+
df.index = pd.Index(range(n)) # Explicit integer Index
1745+
assert isinstance(df.index, pd.Index)
1746+
assert pd.api.types.is_integer_dtype(df.index)
1747+
1748+
model = TransferFunctionOLS(
1749+
saturation_type=None,
1750+
adstock_grid={"half_life": [3]},
1751+
estimation_method="grid",
1752+
error_model="hac",
1753+
)
1754+
1755+
# Should not raise an error
1756+
result = GradedInterventionTimeSeries(
1757+
data=df,
1758+
y_column="y",
1759+
treatment_names=["treatment"],
1760+
base_formula="1 + t",
1761+
model=model,
1762+
)
1763+
1764+
assert result.ols_result is not None
1765+
1766+
1767+
class TestEffectWithARIMAX:
1768+
"""Test effect() method with ARIMAX error model."""
1769+
1770+
def test_effect_with_arimax_model(self):
1771+
"""Test that effect() works correctly with ARIMAX."""
1772+
np.random.seed(42)
1773+
n = 100
1774+
t = np.arange(n)
1775+
dates = pd.date_range("2020-01-01", periods=n, freq="W")
1776+
1777+
treatment_raw = 50 + np.random.uniform(-10, 10, n)
1778+
treatment_raw = np.maximum(treatment_raw, 0)
1779+
1780+
# Create AR(1) errors
1781+
rho = 0.5
1782+
errors = np.zeros(n)
1783+
errors[0] = np.random.normal(0, 10 / np.sqrt(1 - rho**2))
1784+
for i in range(1, n):
1785+
errors[i] = rho * errors[i - 1] + np.random.normal(0, 10)
1786+
1787+
y = 100.0 + 0.5 * t + 50 * treatment_raw + errors
1788+
1789+
df = pd.DataFrame({"date": dates, "t": t, "y": y, "treatment": treatment_raw})
1790+
df = df.set_index("date")
1791+
1792+
model = TransferFunctionOLS(
1793+
saturation_type=None,
1794+
adstock_grid={"half_life": [3]},
1795+
estimation_method="grid",
1796+
error_model="arimax",
1797+
arima_order=(1, 0, 0),
1798+
)
1799+
1800+
result = GradedInterventionTimeSeries(
1801+
data=df,
1802+
y_column="y",
1803+
treatment_names=["treatment"],
1804+
base_formula="1 + t",
1805+
model=model,
1806+
)
1807+
1808+
# Test effect
1809+
effect_result = result.effect(
1810+
window=(df.index[0], df.index[-1]), channels=None, scale=0.0
1811+
)
1812+
1813+
assert "effect_df" in effect_result
1814+
assert "total_effect" in effect_result
1815+
assert effect_result["total_effect"] != 0 # Should have nonzero effect
1816+
1817+
# Test plot_effect
1818+
fig, ax = result.plot_effect(effect_result)
1819+
assert isinstance(fig, plt.Figure)
1820+
plt.close(fig)

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)