Merged
78 changes: 74 additions & 4 deletions src/services/analytics/stat_functions.py
@@ -8,7 +8,7 @@
Key features:
- Welch's t-test for comparing means with unequal variances.
- Z-test for comparing proportions.
- Ratio metric anaysis using the delta method.
- Ratio metric analysis using the delta method.
- Sample size calculations for various test types.
- Support for both scalar and vectorized operations.
"""
@@ -20,7 +20,7 @@
import numpy as np
from scipy.optimize import root_scalar # type: ignore
from scipy.special import ndtr, stdtr, stdtrit # type: ignore
from scipy.stats import norm, t # type: ignore
from scipy.stats import chi2, norm, t # type: ignore
from statsmodels.stats.power import TTestIndPower # type: ignore

from src.services.analytics._constants import NORM_PPF_ALPHA_TWO_SIDED, NORM_PPF_BETA
@@ -29,6 +29,9 @@
# Define types for result
ConfidenceInterval = namedtuple("ConfidenceInterval", ["lower", "upper"])
TestResult = namedtuple("TestResult", ["statistic", "p_value", "ci", "diff_abs", "diff_ratio"])
SRMResult = namedtuple(
"SRMResult", ["statistic", "p_value", "df", "expected", "observed", "allocation", "is_srm"]
)


def ttest_welch(
@@ -103,8 +106,7 @@ def ttest_welch(
ci_upper = t_diff + t_se * stat_ppf
ci = ConfidenceInterval(ci_lower, ci_upper)

# Handle division by zero for diff_ratio
diff_ratio = np.where(mean_1 != 0, t_diff / mean_1, np.inf)
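# np.divide with out/where skips zero denominators entirely, so 0/0 no longer emits a RuntimeWarning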
diff_ratio = np.divide(t_diff, mean_1, out=np.full_like(t_diff, np.inf, dtype=float), where=mean_1 != 0)

return TestResult(statistic=t_stat, p_value=p_value, ci=ci, diff_abs=t_diff, diff_ratio=diff_ratio)

@@ -609,3 +611,71 @@ def sample_size_ratio_metric(

n_required = 2 * (z_alpha + z_beta) ** 2 * ratio_var / (diff**2)
return np.ceil(n_required)


# -------------------------- SAMPLE RATIO MISMATCH FUNCTION -------------------------- #


def sample_ratio_mismatch_test(
observed_counts: list[int] | np.ndarray,
expected_ratios: list[float] | np.ndarray | None = None,
alpha: float = 1e-3,
) -> SRMResult:
"""Sample Ratio Mismatch (SRM) via Pearson's chi-square goodness-of-fit.

Args:
observed_counts: 1-D array-like of non-negative integers per arm.
expected_ratios: 1-D array-like of expected proportions (sum ~ 1).
If None, assumes uniform split.
alpha: Significance level for flagging SRM.

Returns:
SRMResult(statistic, p_value, df, expected, observed, allocation, is_srm)
"""
obs = np.asarray(observed_counts, dtype=float)
if obs.ndim != 1:
raise ValueError("observed_counts must be 1-D")
if obs.size < 2:
raise ValueError("observed_counts must contain at least 2 groups")
if np.any(obs < 0):
raise ValueError("All observed counts must be non-negative")

N = obs.sum()
if not np.isfinite(N) or N <= 0:
raise ValueError("Total count must be positive")

k = obs.size
if expected_ratios is None:
alloc = np.full(k, 1.0 / k, dtype=float)
else:
alloc = np.asarray(expected_ratios, dtype=float)
if alloc.shape != (k,):
raise ValueError("expected_ratios must have the same length as observed_counts")
if np.any(~np.isfinite(alloc)) or np.any(alloc <= 0):
raise ValueError("All expected ratios must be finite and strictly positive")
s = float(alloc.sum())
if not np.isfinite(s) or s <= 0:
raise ValueError("expected_ratios must sum to a positive number")
if not np.isclose(s, 1.0, rtol=1e-6, atol=1e-12):
raise ValueError("expected_ratios must sum to 1 (within tolerance)")

expected = N * alloc
if np.any(expected <= 0):
raise ValueError("Expected counts contain zeros; ensure expected_ratios > 0 and N > 0")

# Pearson's chi-square
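# X^2 = sum((obs_i - exp_i)^2 / exp_i); chi2.sf gives the upper-tail p-value at df = k - 1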
stat = np.sum((obs - expected) ** 2 / expected)
df = k - 1
p = chi2.sf(stat, df)

is_srm = bool(p < alpha)

return SRMResult(
statistic=float(stat),
p_value=float(p),
df=int(df),
expected=expected.astype(float),
observed=obs.astype(int),
allocation=alloc.astype(float),
is_srm=is_srm,
)
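A minimal usage sketch (not part of the diff), reusing the counts from the happy-path test added below; the import path is the one the new tests use:

from src.services.analytics.stat_functions import sample_ratio_mismatch_test

# Intended 50/50 split that drifted to 1000 vs 1200 observations per arm.
result = sample_ratio_mismatch_test([1000, 1200], expected_ratios=[0.5, 0.5], alpha=1e-3)
# chi-square is roughly 18.2 at df = 1, p is roughly 2e-5 < 1e-3, so result.is_srm is True.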
68 changes: 66 additions & 2 deletions src/ui/results/elements.py
@@ -152,8 +152,72 @@ def render(observation_cnt: dict[str, Any]) -> None:
Args:
observation_cnt: A dictionary with observation counts per group.
"""
import pandas as pd

from src.services.analytics.stat_functions import sample_ratio_mismatch_test

with st.expander("SRM Check", expanded=False):
st.markdown("place for srm")
placeholder = st.empty()
if not observation_cnt or len(observation_cnt) < 2:
st.warning("Need at least 2 groups to perform SRM check")
return

groups_from_page = observation_cnt.keys()
selected_groups = st.segmented_control(
label="Select groups",
options=groups_from_page,
key="srm_check_control",
default=groups_from_page,
selection_mode="multi",
)
if len(selected_groups) < 2:
st.warning("Need at least 2 groups to perform SRM check")
return

obs_items = [(k, int(v)) for k, v in observation_cnt.items() if k in selected_groups]
df = pd.DataFrame(obs_items, columns=["group", "counts"])
total_counts = df.counts.sum()
df["current_ratio"] = df.counts / total_counts
df["expected_ratio"] = 1 / df.shape[0]

col1, col2, col3 = st.columns([1, 1, 2])
expected_ratios = []
col1.markdown("**Group**")
col2.markdown("**Counts**")
col3.markdown("**Expected**")
for row in df.iterrows():
series = row[1]
col1, col2, col3 = st.columns([1, 1, 2])
col1.write(series.group)
col2.text(f"{int(series.counts)}")
with col3:
ratio_value = st.number_input(
"Ratio",
min_value=0.001,
max_value=1.0,
value=series.expected_ratio,
step=0.01,
key=f"srm_ratio_{series.group}",
label_visibility="collapsed",
)
expected_ratios.append(ratio_value)

if st.button("🔍 Check for SRM", type="primary"):
try:
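# alpha=1e-3 is the commonly used, deliberately strict threshold for SRM alerts.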
result = sample_ratio_mismatch_test(
observed_counts=df.counts, expected_ratios=expected_ratios, alpha=1e-3
)
if result.is_srm:
placeholder.error(
f"Sample Ratio Mismatch detected! p-value: {result.p_value:.4f}", icon="🔥"
)
else:
placeholder.success(
f"No Sample Ratio Mismatch detected. p-value: {result.p_value:.4f}", icon="✅"
)

except Exception as e:
placeholder.warning(f"Error running SRM test: {e}")


class ResultsDataframes:
@@ -345,7 +409,7 @@ def render(cls, precomputes: pd.DataFrame | None) -> SettingsColumnLayout | None
p_value_threshold_filter = PValueThresholdFilter.render().threshold
metric_filters = MetricsFilters.render()

observations_cnt = precomputes.groupby("group_name")["observation_cnt"].unique().to_dict()
observations_cnt = precomputes.groupby("group_name")["observation_cnt"].max().astype(int).to_dict()
SampleRatioMismatchCheckExpander.render(observations_cnt)

return cls(
34 changes: 19 additions & 15 deletions tests/assistant/test_app.py
@@ -6,6 +6,7 @@

from __future__ import annotations

from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

@@ -36,60 +37,63 @@ def user_payload() -> dict[str, Any]:
# ---------------------------------- Happy‑path test ----------------------------------


@patch("assistant.app.logger.instrument_requests", lambda *a, **k: None)
@patch("assistant.app.logger.instrument_sqlalchemy", lambda *a, **k: None)
@patch("assistant.app.init_assistant_service")
@patch("assistant.app.init_engine")
@patch("assistant.app.init_vdb")
def test_invoke_success(mock_vdb, mock_engine, mock_service, user_payload, dummy_usage, patch_configs):
"""/invoke returns AgentResponse and disposes engine on shutdown."""

fake_response = AssistantResponse(output="hi!", usage=dummy_usage, thinking="thinking")
# assistant_service is async, so we need an AsyncMock

# assistant_service — async
mock_service_instance = AsyncMock()
mock_service_instance.process_request.return_value = fake_response
mock_service.return_value = mock_service_instance

# Mock DB engine + vdb
mock_engine_instance = AsyncMock()
mock_engine_instance.dispose = AsyncMock()
mock_engine.return_value = mock_engine_instance
# Engine: an object with an async dispose()
engine = MagicMock()
engine.dispose = AsyncMock()
mock_engine.return_value = engine

# vdb: any sync mock
mock_vdb_instance = MagicMock()
mock_vdb.return_value = mock_vdb_instance

with TestClient(app) as client:
# --- request ---
resp = client.post("/invoke", json=user_payload)
assert resp.status_code == 200
assert resp.json()["output"] == "hi!"

# After exiting TestClient, shutdown event has run
mock_engine_instance.dispose.assert_awaited_once()
engine.dispose.assert_awaited_once()
mock_service_instance.process_request.assert_awaited_once()


# ------------------------ Validation error (missing chat_uid) ------------------------


@patch("assistant.app.logger.instrument_requests", lambda *a, **k: None)
@patch("assistant.app.logger.instrument_sqlalchemy", lambda *a, **k: None)
@patch("assistant.app.init_assistant_service", lambda *a, **kw: AsyncMock())
@patch("assistant.app.init_engine", lambda *a, **kw: AsyncMock())
@patch("assistant.app.init_engine", lambda *a, **kw: SimpleNamespace(dispose=AsyncMock()))
@patch("assistant.app.init_vdb", lambda *a, **kw: MagicMock())
def test_invoke_validation_error(user_payload, patch_configs):
"""Test validation error when chat_uid is missing."""
payload = user_payload.copy()
payload.pop("chat_uid")
with TestClient(app) as client:
resp = client.post("/invoke", json=payload)
assert resp.status_code == 422 # FastAPI validation error
assert resp.status_code == 422


#
# ---------------------------- Error path: assistant raises ----------------------------


@patch("assistant.app.logger.instrument_requests", lambda *a, **k: None)
@patch("assistant.app.logger.instrument_sqlalchemy", lambda *a, **k: None)
@patch("assistant.app.init_assistant_service")
@patch("assistant.app.init_engine", lambda *a, **kw: AsyncMock())
@patch("assistant.app.init_engine", lambda *a, **kw: SimpleNamespace(dispose=AsyncMock()))
@patch("assistant.app.init_vdb", lambda *a, **kw: MagicMock())
def test_invoke_service_error(mock_service, user_payload, patch_configs):
"""Test error handling when assistant service raises exception."""
err = RuntimeError("model crashed")
mock_service_instance = AsyncMock()
mock_service_instance.process_request.side_effect = err
39 changes: 39 additions & 0 deletions tests/services/analytics/stat_functions/test_srm.py
@@ -0,0 +1,39 @@
import pytest

from src.services.analytics.stat_functions import sample_ratio_mismatch_test


@pytest.mark.parametrize(
"observed_counts, expected_ratios, expected_is_srm, p_value_threshold",
[
([1000, 1010], None, False, 0.001),
([1000, 1200], None, True, 0.001),
([100, 205], [1 / 3, 2 / 3], False, 0.001),
([300, 700], [0.25, 0.75], True, 0.001),
],
)
def test_sample_ratio_mismatch_happy_path(
observed_counts, expected_ratios, expected_is_srm, p_value_threshold
):
result = sample_ratio_mismatch_test(observed_counts, expected_ratios, alpha=p_value_threshold)
assert result.is_srm is expected_is_srm
if expected_is_srm:
assert result.p_value < p_value_threshold
else:
assert result.p_value >= p_value_threshold


@pytest.mark.parametrize(
"observed_counts, expected_ratios, error_message",
[
([100], None, "observed_counts must contain at least 2 groups"),
([100, -10], None, "All observed counts must be non-negative"),
([100, 100], [0.5, 0.4, 0.1], "expected_ratios must have the same length as observed_counts"),
([100, 100], [0.6, 0.6], "expected_ratios must sum to 1"),
([0, 0], None, "Total count must be positive"),
],
)
def test_sample_ratio_mismatch_corner_cases(observed_counts, expected_ratios, error_message):
"""Tests corner cases and invalid inputs for sample_ratio_mismatch_test."""
with pytest.raises(ValueError, match=error_message):
sample_ratio_mismatch_test(observed_counts, expected_ratios)
19 changes: 19 additions & 0 deletions tests/services/analytics/stat_functions/test_ttest_welch.py
@@ -171,3 +171,22 @@ def test_sample_size_t_test_edge_cases():
alpha=0.05,
beta=0.2,
)


def test_ttest_welch_zero_division_no_warning(recwarn):
"""Tests that no RuntimeWarning is raised for 0/0 division with np.divide."""
mean_1 = np.array([10, 0])
var_1 = np.array([1, 1])
n_1 = np.array([100, 100])
mean_2 = np.array([12, 0])
var_2 = np.array([1, 1])
n_2 = np.array([100, 100])

result = ttest_welch(mean_1, var_1, n_1, mean_2, var_2, n_2)

# Check that no warnings were issued
assert len(recwarn) == 0

# The result for the 0/0 case should be np.inf with the np.divide fix
expected_diff_ratio = np.array([0.2, np.inf])
assert np.allclose(result.diff_ratio, expected_diff_ratio, equal_nan=True)