Skip to content

Commit 4fdf92b

Browse files
committed
srm check finished
1 parent e912176 commit 4fdf92b

File tree

3 files changed: 58 additions and 16 deletions.

src/services/analytics/stat_functions.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def sample_size_ratio_metric(
620620
def sample_ratio_mismatch_test(
621621
observed_counts: list[int] | np.ndarray,
622622
expected_ratios: list[float] | np.ndarray | None = None,
623-
alpha: float = 1e-3,
623+
alpha: float = 1e-3,
624624
) -> SRMResult:
625625
"""Sample Ratio Mismatch (SRM) via Pearson's chi-square goodness-of-fit.
626626
@@ -652,17 +652,16 @@ def sample_ratio_mismatch_test(
652652
alloc = np.asarray(expected_ratios, dtype=float)
653653
if alloc.shape != (k,):
654654
raise ValueError("expected_ratios must have the same length as observed_counts")
655-
if np.any(alloc < 0):
656-
raise ValueError("All expected ratios must be non-negative")
657-
s = alloc.sum()
655+
if np.any(~np.isfinite(alloc)) or np.any(alloc <= 0):
656+
raise ValueError("All expected ratios must be finite and strictly positive")
657+
s = float(alloc.sum())
658658
if not np.isfinite(s) or s <= 0:
659659
raise ValueError("expected_ratios must sum to a positive number")
660-
if not np.isclose(alloc.sum(), 1):
661-
raise ValueError("expected_ratios must sum to 1")
662-
alloc = alloc / s
660+
if not np.isclose(s, 1.0, rtol=1e-6, atol=1e-12):
661+
raise ValueError("expected_ratios must sum to 1 (within tolerance)")
663662

664663
expected = N * alloc
665-
if np.any(expected == 0):
664+
if np.any(expected <= 0):
666665
raise ValueError("Expected counts contain zeros; ensure expected_ratios > 0 and N > 0")
667666

668667
# Pearson's chi-square
@@ -680,4 +679,4 @@ def sample_ratio_mismatch_test(
680679
observed=obs.astype(int),
681680
allocation=alloc.astype(float),
682681
is_srm=is_srm,
683-
)
682+
)

src/ui/results/elements.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,10 @@ def render(observation_cnt: dict[str, Any]) -> None:
152152
Args:
153153
observation_cnt: A dictionary with observation counts per group.
154154
"""
155-
from src.services.analytics.stat_functions import sample_ratio_mismatch_test
156155
import pandas as pd
157156

157+
from src.services.analytics.stat_functions import sample_ratio_mismatch_test
158+
158159
with st.expander("SRM Check", expanded=False):
159160
placeholder = st.empty()
160161
if not observation_cnt or len(observation_cnt) < 2:
@@ -167,7 +168,7 @@ def render(observation_cnt: dict[str, Any]) -> None:
167168
options=groups_from_page,
168169
key="srm_check_control",
169170
default=groups_from_page,
170-
selection_mode="multi"
171+
selection_mode="multi",
171172
)
172173
if len(selected_groups) < 2:
173174
st.warning("Need at least 2 groups to perform SRM check")
@@ -197,23 +198,26 @@ def render(observation_cnt: dict[str, Any]) -> None:
197198
value=series.expected_ratio,
198199
step=0.01,
199200
key=f"srm_ratio_{series.group}",
200-
label_visibility="collapsed"
201+
label_visibility="collapsed",
201202
)
202203
expected_ratios.append(ratio_value)
203204

204-
205205
if st.button("🔍 Check for SRM", type="primary"):
206206
try:
207207
result = sample_ratio_mismatch_test(
208208
observed_counts=df.counts, expected_ratios=expected_ratios, alpha=1e-3
209209
)
210210
if result.is_srm:
211-
placeholder.error(f"Sample Ratio Mismatch detected! p-value: {result.p_value:.4f}", icon="🔥")
211+
placeholder.error(
212+
f"Sample Ratio Mismatch detected! p-value: {result.p_value:.4f}", icon="🔥"
213+
)
212214
else:
213-
placeholder.success(f"No Sample Ratio Mismatch detected. p-value: {result.p_value:.4f}", icon="✅")
215+
placeholder.success(
216+
f"No Sample Ratio Mismatch detected. p-value: {result.p_value:.4f}", icon="✅"
217+
)
214218

215219
except Exception as e:
216-
placeholder.error(f"Error running SRM test: {e}")
220+
placeholder.warning(f"Error running SRM test: {e}")
217221

218222

219223
class ResultsDataframes:
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
from src.services.analytics.stat_functions import sample_ratio_mismatch_test
4+
5+
6+
@pytest.mark.parametrize(
7+
"observed_counts, expected_ratios, expected_is_srm, p_value_threshold",
8+
[
9+
([1000, 1010], None, False, 0.001),
10+
([1000, 1200], None, True, 0.001),
11+
([100, 205], [1 / 3, 2 / 3], False, 0.001),
12+
([300, 700], [0.25, 0.75], True, 0.001),
13+
],
14+
)
15+
def test_sample_ratio_mismatch_happy_path(
16+
observed_counts, expected_ratios, expected_is_srm, p_value_threshold
17+
):
18+
result = sample_ratio_mismatch_test(observed_counts, expected_ratios, alpha=p_value_threshold)
19+
assert result.is_srm is expected_is_srm
20+
if expected_is_srm:
21+
assert result.p_value < p_value_threshold
22+
else:
23+
assert result.p_value >= p_value_threshold
24+
25+
26+
@pytest.mark.parametrize(
27+
"observed_counts, expected_ratios, error_message",
28+
[
29+
([100], None, "observed_counts must contain at least 2 groups"),
30+
([100, -10], None, "All observed counts must be non-negative"),
31+
([100, 100], [0.5, 0.4, 0.1], "expected_ratios must have the same length as observed_counts"),
32+
([100, 100], [0.6, 0.6], "expected_ratios must sum to 1"),
33+
([0, 0], None, "Total count must be positive"),
34+
],
35+
)
36+
def test_sample_ratio_mismatch_corner_cases(observed_counts, expected_ratios, error_message):
37+
"""Tests corner cases and invalid inputs for sample_ratio_mismatch_test."""
38+
with pytest.raises(ValueError, match=error_message):
39+
sample_ratio_mismatch_test(observed_counts, expected_ratios)

0 commit comments

Comments
 (0)