Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/services/analytics/stat_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
from scipy.optimize import root_scalar # type: ignore
from scipy.special import ndtr, stdtr, stdtrit # type: ignore
from scipy.stats import norm, t # type: ignore
from scipy.stats import chi2, norm, t # type: ignore
from statsmodels.stats.power import TTestIndPower # type: ignore

from src.services.analytics._constants import NORM_PPF_ALPHA_TWO_SIDED, NORM_PPF_BETA
Expand All @@ -29,6 +29,9 @@
# Define types for result
# Lightweight immutable containers returned by the statistical functions below.
ConfidenceInterval = namedtuple("ConfidenceInterval", ["lower", "upper"])
TestResult = namedtuple("TestResult", ["statistic", "p_value", "ci", "diff_abs", "diff_ratio"])
# Result of the Sample Ratio Mismatch chi-square goodness-of-fit test:
# chi-square statistic, p-value, degrees of freedom, per-arm expected and
# observed counts, the allocation (expected proportions) used, and the
# boolean flag telling whether SRM was detected at the chosen alpha.
SRMResult = namedtuple(
    "SRMResult", ["statistic", "p_value", "df", "expected", "observed", "allocation", "is_srm"]
)


def ttest_welch(
Expand Down Expand Up @@ -609,3 +612,71 @@ def sample_size_ratio_metric(

n_required = 2 * (z_alpha + z_beta) ** 2 * ratio_var / (diff**2)
return np.ceil(n_required)


# -------------------------- SAMPLE RATIO MISMATCH FUNCTION -------------------------- #


def sample_ratio_mismatch_test(
    observed_counts: list[int] | np.ndarray,
    expected_ratios: list[float] | np.ndarray | None = None,
    alpha: float = 1e-3,
) -> SRMResult:
    """Sample Ratio Mismatch (SRM) via Pearson's chi-square goodness-of-fit.

    Args:
        observed_counts: 1-D array-like of non-negative integers per arm.
        expected_ratios: 1-D array-like of expected proportions (sum ~ 1).
            If None, assumes uniform split.
        alpha: Significance level for flagging SRM.

    Returns:
        SRMResult(statistic, p_value, df, expected_counts, observed_counts, allocation, is_srm)

    Raises:
        ValueError: On non-1-D input, fewer than two arms, negative counts,
            a non-positive total, or malformed expected ratios.
    """
    observed = np.asarray(observed_counts, dtype=float)

    # --- validate the observed counts -------------------------------------
    if observed.ndim != 1:
        raise ValueError("observed_counts must be 1-D")
    if observed.size < 2:
        raise ValueError("observed_counts must contain at least 2 groups")
    if (observed < 0).any():
        raise ValueError("All observed counts must be non-negative")

    total = observed.sum()
    # Also rejects NaN/inf totals (NaN counts slip past the `< 0` test above).
    if not np.isfinite(total) or total <= 0:
        raise ValueError("Total count must be positive")

    # --- resolve the expected allocation ----------------------------------
    n_arms = observed.size
    if expected_ratios is None:
        # Default assumption: traffic split evenly across all arms.
        allocation = np.full(n_arms, 1.0 / n_arms, dtype=float)
    else:
        allocation = np.asarray(expected_ratios, dtype=float)
        if allocation.shape != (n_arms,):
            raise ValueError("expected_ratios must have the same length as observed_counts")
        if (~np.isfinite(allocation)).any() or (allocation <= 0).any():
            raise ValueError("All expected ratios must be finite and strictly positive")
        ratio_sum = float(allocation.sum())
        if not np.isfinite(ratio_sum) or ratio_sum <= 0:
            raise ValueError("expected_ratios must sum to a positive number")
        if not np.isclose(ratio_sum, 1.0, rtol=1e-6, atol=1e-12):
            raise ValueError("expected_ratios must sum to 1 (within tolerance)")

    expected = total * allocation
    if (expected <= 0).any():
        raise ValueError("Expected counts contain zeros; ensure expected_ratios > 0 and N > 0")

    # --- Pearson's chi-square goodness-of-fit -----------------------------
    statistic = np.sum((observed - expected) ** 2 / expected)
    degrees_of_freedom = n_arms - 1
    # Survival function = P(chi2_df >= statistic), i.e. the upper-tail p-value.
    p_value = chi2.sf(statistic, degrees_of_freedom)

    return SRMResult(
        statistic=float(statistic),
        p_value=float(p_value),
        df=int(degrees_of_freedom),
        expected=expected.astype(float),
        observed=observed.astype(int),
        allocation=allocation.astype(float),
        is_srm=bool(p_value < alpha),
    )
66 changes: 65 additions & 1 deletion src/ui/results/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,72 @@ def render(observation_cnt: dict[str, Any]) -> None:
Args:
observation_cnt: A dictionary with observation counts per group.
"""
import pandas as pd

from src.services.analytics.stat_functions import sample_ratio_mismatch_test

with st.expander("SRM Check", expanded=False):
st.markdown("place for srm")
# Placeholder at the top of the expander so the verdict renders above the inputs.
placeholder = st.empty()
# SRM requires at least two arms to compare.
if not observation_cnt or len(observation_cnt) < 2:
    st.warning("Need at least 2 groups to perform SRM check")
    return

groups_from_page = observation_cnt.keys()
# Let the user choose which groups participate; all groups selected by default.
selected_groups = st.segmented_control(
    label="Select groups",
    options=groups_from_page,
    key="srm_check_control",
    default=groups_from_page,
    selection_mode="multi",
)
if len(selected_groups) < 2:
    st.warning("Need at least 2 groups to perform SRM check")
    return

# Build a (group, count) frame restricted to the selected groups.
obs_items = [(k, int(v)) for k, v in observation_cnt.items() if k in selected_groups]
df = pd.DataFrame(obs_items, columns=["group", "counts"])
total_counts = df.counts.sum()
df["current_ratio"] = df.counts / total_counts
# Default expectation: an even split across the selected groups.
df["expected_ratio"] = 1 / df.shape[0]

# Header row of the editable allocation table.
col1, col2, col3 = st.columns([1, 1, 2])
expected_ratios = []
col1.markdown("**Group**")
col2.markdown("**Counts**")
col3.markdown("**Expected**")
for row in df.iterrows():
    series = row[1]
    # One column-row per group; only the expected ratio is user-editable.
    col1, col2, col3 = st.columns([1, 1, 2])
    col1.write(series.group)
    col2.text(f"{int(series.counts)}")
    with col3:
        ratio_value = st.number_input(
            "Ratio",
            min_value=0.001,
            max_value=1.0,
            value=series.expected_ratio,
            step=0.01,
            key=f"srm_ratio_{series.group}",
            label_visibility="collapsed",
        )
        expected_ratios.append(ratio_value)

if st.button("🔍 Check for SRM", type="primary"):
    try:
        # The stat function validates the user-entered ratios (e.g. they must
        # sum to 1); validation failures surface via the except branch below.
        result = sample_ratio_mismatch_test(
            observed_counts=df.counts, expected_ratios=expected_ratios, alpha=1e-3
        )
        if result.is_srm:
            placeholder.error(
                f"Sample Ratio Mismatch detected! p-value: {result.p_value:.4f}", icon="🔥"
            )
        else:
            placeholder.success(
                f"No Sample Ratio Mismatch detected. p-value: {result.p_value:.4f}", icon="✅"
            )

    except Exception as e:
        # Broad catch is deliberate: any input/validation problem becomes a
        # non-fatal warning in the UI rather than a crashed page.
        placeholder.warning(f"Error running SRM test: {e}")


class ResultsDataframes:
Expand Down
39 changes: 39 additions & 0 deletions tests/services/analytics/stat_functions/test_srm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from src.services.analytics.stat_functions import sample_ratio_mismatch_test


@pytest.mark.parametrize(
    "observed_counts, expected_ratios, expected_is_srm, p_value_threshold",
    [
        # Uniform split, counts nearly balanced -> no SRM.
        ([1000, 1010], None, False, 0.001),
        # Uniform split, 20% imbalance -> SRM flagged.
        ([1000, 1200], None, True, 0.001),
        # Custom 1:2 allocation closely matched -> no SRM.
        ([100, 205], [1 / 3, 2 / 3], False, 0.001),
        # Custom allocation badly violated -> SRM flagged.
        ([300, 700], [0.25, 0.75], True, 0.001),
    ],
)
def test_sample_ratio_mismatch_happy_path(
    observed_counts, expected_ratios, expected_is_srm, p_value_threshold
):
    """The SRM verdict and its p-value agree on valid inputs."""
    result = sample_ratio_mismatch_test(observed_counts, expected_ratios, alpha=p_value_threshold)

    assert result.is_srm is expected_is_srm
    # The flag must be exactly the "p-value below alpha" decision.
    assert (result.p_value < p_value_threshold) == expected_is_srm


@pytest.mark.parametrize(
    "observed_counts, expected_ratios, error_message",
    [
        # Fewer than two arms.
        ([100], None, "observed_counts must contain at least 2 groups"),
        # Negative observed count.
        ([100, -10], None, "All observed counts must be non-negative"),
        # Ratio vector length does not match the number of arms.
        ([100, 100], [0.5, 0.4, 0.1], "expected_ratios must have the same length as observed_counts"),
        # Ratios do not sum to one.
        ([100, 100], [0.6, 0.6], "expected_ratios must sum to 1"),
        # Zero total traffic.
        ([0, 0], None, "Total count must be positive"),
    ],
)
def test_sample_ratio_mismatch_corner_cases(observed_counts, expected_ratios, error_message):
    """Tests corner cases and invalid inputs for sample_ratio_mismatch_test."""
    with pytest.raises(ValueError) as excinfo:
        sample_ratio_mismatch_test(observed_counts, expected_ratios)
    # Substring check is equivalent to `match=` here: the messages are plain text.
    assert error_message in str(excinfo.value)