Skip to content

Commit aac61d3

Browse files
Feature/result page (#37)
* start adding srm and groupd_filter to result page * changed pages design and elements layouts * intermediate changes * final commit for toggle functionality * add tests * start working on srm functionality * srm check ui finished * srm check finished
1 parent 7344860 commit aac61d3

File tree

3 files changed

+176
-2
lines changed

3 files changed

+176
-2
lines changed

src/services/analytics/stat_functions.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121
from scipy.optimize import root_scalar # type: ignore
2222
from scipy.special import ndtr, stdtr, stdtrit # type: ignore
23-
from scipy.stats import norm, t # type: ignore
23+
from scipy.stats import chi2, norm, t # type: ignore
2424
from statsmodels.stats.power import TTestIndPower # type: ignore
2525

2626
from src.services.analytics._constants import NORM_PPF_ALPHA_TWO_SIDED, NORM_PPF_BETA
@@ -29,6 +29,9 @@
2929
# Define types for result
3030
ConfidenceInterval = namedtuple("ConfidenceInterval", ["lower", "upper"])
3131
TestResult = namedtuple("TestResult", ["statistic", "p_value", "ci", "diff_abs", "diff_ratio"])
32+
# Result container for the sample ratio mismatch (SRM) chi-square check:
# test statistic, p-value, degrees of freedom, expected and observed counts,
# the allocation (expected proportions) used, and the final SRM flag.
SRMResult = namedtuple(
    "SRMResult",
    "statistic p_value df expected observed allocation is_srm",
)
3235

3336

3437
def ttest_welch(
@@ -609,3 +612,71 @@ def sample_size_ratio_metric(
609612

610613
n_required = 2 * (z_alpha + z_beta) ** 2 * ratio_var / (diff**2)
611614
return np.ceil(n_required)
615+
616+
617+
# -------------------------- SAMPLE RATIO MISMATCH FUNCTION -------------------------- #
618+
619+
620+
def sample_ratio_mismatch_test(
    observed_counts: list[int] | np.ndarray,
    expected_ratios: list[float] | np.ndarray | None = None,
    alpha: float = 1e-3,
) -> SRMResult:
    """Sample Ratio Mismatch (SRM) via Pearson's chi-square goodness-of-fit.

    Compares the observed number of units per arm with the counts implied by
    the expected allocation and flags SRM when the chi-square p-value falls
    below ``alpha``.

    Args:
        observed_counts: 1-D array-like of non-negative integers per arm.
        expected_ratios: 1-D array-like of strictly positive expected
            proportions that sum to 1 (within a small tolerance).
            If None, assumes uniform split.
        alpha: Significance level for flagging SRM; must lie strictly
            between 0 and 1.

    Returns:
        SRMResult(statistic, p_value, df, expected_counts, observed_counts, allocation, is_srm)

    Raises:
        ValueError: If ``alpha`` is outside (0, 1); if ``observed_counts`` is
            not 1-D, has fewer than 2 groups, or contains non-finite, negative
            or non-integer values; if the total count is not positive; or if
            ``expected_ratios`` has the wrong length, contains non-finite or
            non-positive values, or does not sum to 1.
    """
    # The chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must be strictly between 0 and 1")

    obs = np.asarray(observed_counts, dtype=float)
    if obs.ndim != 1:
        raise ValueError("observed_counts must be 1-D")
    if obs.size < 2:
        raise ValueError("observed_counts must contain at least 2 groups")
    # Reject NaN/inf explicitly: NaN slips through the `< 0` comparison and
    # would otherwise only surface later as a misleading "total count" error.
    if np.any(~np.isfinite(obs)):
        raise ValueError("All observed counts must be finite")
    if np.any(obs < 0):
        raise ValueError("All observed counts must be non-negative")
    # Counts are cast to int in the result; refuse fractional values instead
    # of silently truncating them and reporting counts that disagree with the
    # statistic.
    if np.any(obs % 1 != 0):
        raise ValueError("All observed counts must be integers")

    N = obs.sum()
    if not np.isfinite(N) or N <= 0:
        raise ValueError("Total count must be positive")

    k = obs.size
    if expected_ratios is None:
        alloc = np.full(k, 1.0 / k, dtype=float)
    else:
        alloc = np.asarray(expected_ratios, dtype=float)
        if alloc.shape != (k,):
            raise ValueError("expected_ratios must have the same length as observed_counts")
        if np.any(~np.isfinite(alloc)) or np.any(alloc <= 0):
            raise ValueError("All expected ratios must be finite and strictly positive")
        # Every ratio is finite and strictly positive here, so the sum is
        # positive; only its closeness to 1 still needs checking.
        s = float(alloc.sum())
        if not np.isclose(s, 1.0, rtol=1e-6, atol=1e-12):
            raise ValueError("expected_ratios must sum to 1 (within tolerance)")

    expected = N * alloc
    # Guard the division below against a zero expected count (e.g. underflow).
    if np.any(expected <= 0):
        raise ValueError("Expected counts contain zeros; ensure expected_ratios > 0 and N > 0")

    # Pearson's chi-square
    stat = np.sum((obs - expected) ** 2 / expected)
    df = k - 1
    p = chi2.sf(stat, df)

    is_srm = bool(p < alpha)

    return SRMResult(
        statistic=float(stat),
        p_value=float(p),
        df=int(df),
        expected=expected.astype(float),
        observed=obs.astype(int),
        # astype copies, so the result never aliases the caller's array.
        allocation=alloc.astype(float),
        is_srm=is_srm,
    )

src/ui/results/elements.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,72 @@ def render(observation_cnt: dict[str, Any]) -> None:
152152
Args:
153153
observation_cnt: A dictionary with observation counts per group.
154154
"""
155+
import pandas as pd
156+
157+
from src.services.analytics.stat_functions import sample_ratio_mismatch_test
158+
155159
with st.expander("SRM Check", expanded=False):
156-
st.markdown("place for srm")
160+
placeholder = st.empty()
161+
if not observation_cnt or len(observation_cnt) < 2:
162+
st.warning("Need at least 2 groups to perform SRM check")
163+
return
164+
165+
groups_from_page = observation_cnt.keys()
166+
selected_groups = st.segmented_control(
167+
label="Select groups",
168+
options=groups_from_page,
169+
key="srm_check_control",
170+
default=groups_from_page,
171+
selection_mode="multi",
172+
)
173+
if len(selected_groups) < 2:
174+
st.warning("Need at least 2 groups to perform SRM check")
175+
return
176+
177+
obs_items = [(k, int(v)) for k, v in observation_cnt.items() if k in selected_groups]
178+
df = pd.DataFrame(obs_items, columns=["group", "counts"])
179+
total_counts = df.counts.sum()
180+
df["current_ratio"] = df.counts / total_counts
181+
df["expected_ratio"] = 1 / df.shape[0]
182+
183+
col1, col2, col3 = st.columns([1, 1, 2])
184+
expected_ratios = []
185+
col1.markdown("**Group**")
186+
col2.markdown("**Counts**")
187+
col3.markdown("**Expected**")
188+
for row in df.iterrows():
189+
series = row[1]
190+
col1, col2, col3 = st.columns([1, 1, 2])
191+
col1.write(series.group)
192+
col2.text(f"{int(series.counts)}")
193+
with col3:
194+
ratio_value = st.number_input(
195+
"Ratio",
196+
min_value=0.001,
197+
max_value=1.0,
198+
value=series.expected_ratio,
199+
step=0.01,
200+
key=f"srm_ratio_{series.group}",
201+
label_visibility="collapsed",
202+
)
203+
expected_ratios.append(ratio_value)
204+
205+
if st.button("🔍 Check for SRM", type="primary"):
206+
try:
207+
result = sample_ratio_mismatch_test(
208+
observed_counts=df.counts, expected_ratios=expected_ratios, alpha=1e-3
209+
)
210+
if result.is_srm:
211+
placeholder.error(
212+
f"Sample Ratio Mismatch detected! p-value: {result.p_value:.4f}", icon="🔥"
213+
)
214+
else:
215+
placeholder.success(
216+
f"No Sample Ratio Mismatch detected. p-value: {result.p_value:.4f}", icon="✅"
217+
)
218+
219+
except Exception as e:
220+
placeholder.warning(f"Error running SRM test: {e}")
157221

158222

159223
class ResultsDataframes:
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
from src.services.analytics.stat_functions import sample_ratio_mismatch_test
4+
5+
6+
@pytest.mark.parametrize(
    "observed_counts, expected_ratios, expected_is_srm, p_value_threshold",
    [
        ([1000, 1010], None, False, 0.001),
        ([1000, 1200], None, True, 0.001),
        ([100, 205], [1 / 3, 2 / 3], False, 0.001),
        ([300, 700], [0.25, 0.75], True, 0.001),
    ],
)
def test_sample_ratio_mismatch_happy_path(
    observed_counts, expected_ratios, expected_is_srm, p_value_threshold
):
    """Checks the SRM flag and the p-value side of the threshold for valid inputs."""
    outcome = sample_ratio_mismatch_test(
        observed_counts,
        expected_ratios,
        alpha=p_value_threshold,
    )

    # The flag must match the expectation exactly (a real bool, not just truthy).
    assert outcome.is_srm is expected_is_srm

    # The p-value must sit on the side of the threshold implied by the flag.
    if not expected_is_srm:
        assert outcome.p_value >= p_value_threshold
    else:
        assert outcome.p_value < p_value_threshold
24+
25+
26+
@pytest.mark.parametrize(
    "observed_counts, expected_ratios, error_message",
    [
        # Fewer than two groups.
        ([100], None, "observed_counts must contain at least 2 groups"),
        # Negative count in one arm.
        ([100, -10], None, "All observed counts must be non-negative"),
        # Ratios and counts differ in length.
        ([100, 100], [0.5, 0.4, 0.1], "expected_ratios must have the same length as observed_counts"),
        # Ratios do not sum to one.
        ([100, 100], [0.6, 0.6], "expected_ratios must sum to 1"),
        # No observations at all.
        ([0, 0], None, "Total count must be positive"),
    ],
)
def test_sample_ratio_mismatch_corner_cases(observed_counts, expected_ratios, error_message):
    """Verifies that invalid inputs are rejected with a ValueError carrying the expected message."""
    with pytest.raises(ValueError, match=error_message):
        sample_ratio_mismatch_test(
            observed_counts=observed_counts,
            expected_ratios=expected_ratios,
        )

0 commit comments

Comments
 (0)