Skip to content

Commit d075b58

Browse files
cauchyturingclaude
andcommitted
fix: audit round 5 — 3 bugs (edge cases + missed copilot.py int() cast)
B1: prepare_treatment raises ValueError for single-value treatment (was IndexError on unique_vals[1] when nunique < 2) B2: copilot.py matching removed int() cast on control/treatment values (Round 3 B3 only fixed server.py, missed copilot.py) B3: server.py T0/T1 parsing handles categorical string values (float("A") crashed with ValueError for categorical treatments) Audit confirmed: 0 GlobalState field gaps, 0 scientific rigor issues, full main.py parity. 134 tests pass (3 new). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cf10aa3 commit d075b58

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

causal_copilot/copilot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,8 +820,8 @@ def estimate_effect(
820820
treatment,
821821
outcome,
822822
match_conf,
823-
int(control_value),
824-
int(treatment_value),
823+
control_value,
824+
treatment_value,
825825
)
826826
elif selected_method == "dml":
827827
X_col = [c for c in names if c != treatment and c != outcome and c not in conf_list]

causal_copilot/mcp/offline.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ def prepare_treatment(
7777
"""
7878
treatment_col = data[treatment]
7979

80+
if treatment_col.nunique() < 2:
81+
raise ValueError(
82+
f"Treatment '{treatment}' has {treatment_col.nunique()} unique value(s). "
83+
"Need at least 2 distinct values for causal effect estimation."
84+
)
85+
8086
# Case 1: String / object / category
8187
if treatment_col.dtype == "object" or treatment_col.dtype.name == "category":
8288
unique_vals = sorted(treatment_col.unique().tolist())

causal_copilot/mcp/server.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,17 @@ def estimate_effect(
490490

491491
# Parse user-provided T0/T1 or leave as None for auto-detection.
492492
# Auto-detect: binary/discrete → min/max, continuous → 10th/90th percentile.
493-
T0_input = float(control_value) if control_value else None
494-
T1_input = float(treatment_value) if treatment_value else None
493+
# Use try/except for float() to support categorical string values (e.g., "A", "B").
494+
def _parse_tv(val):
495+
if not val:
496+
return None
497+
try:
498+
return float(val)
499+
except (ValueError, TypeError):
500+
return val # categorical string
501+
502+
T0_input = _parse_tv(control_value)
503+
T1_input = _parse_tv(treatment_value)
495504
_, T0_computed, T1_computed, treatment_kind = prepare_treatment(df, treatment, T0=T0_input, T1=T1_input)
496505
control_value = T0_computed
497506
treatment_value = T1_computed

tests/test_mcp.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,6 +2653,42 @@ def test_b8_copilot_extracts_lagged_graph(self):
26532653
assert "lagged_graph" in source
26542654

26552655

2656+
# ── Audit Round 5 — Bug Fixes ──────────────────────────────────────────
2657+
2658+
2659+
class TestAuditRound5Bugs:
2660+
"""Tests for bugs found in the fifth comprehensive audit."""
2661+
2662+
def test_b1_prepare_treatment_single_value_raises(self):
2663+
"""B1: prepare_treatment must raise ValueError for single-value treatment."""
2664+
from causal_copilot.mcp.offline import prepare_treatment
2665+
2666+
df = pd.DataFrame({"T": [1, 1, 1, 1], "Y": [2, 3, 4, 5]})
2667+
with pytest.raises(ValueError, match="unique value"):
2668+
prepare_treatment(df, "T")
2669+
2670+
def test_b2_copilot_matching_no_int_cast(self):
2671+
"""B2: copilot.py matching must NOT int()-cast control/treatment values."""
2672+
import inspect
2673+
2674+
from causal_copilot.copilot import CausalCopilot
2675+
2676+
source = inspect.getsource(CausalCopilot)
2677+
# The estimate_matching call should NOT have int() wrapping
2678+
assert "int(control_value)" not in source
2679+
assert "int(treatment_value)" not in source
2680+
2681+
def test_b3_server_categorical_t0_t1(self):
2682+
"""B3: server.py T0/T1 parsing handles categorical string values."""
2683+
import inspect
2684+
2685+
from causal_copilot.mcp.server import estimate_effect
2686+
2687+
source = inspect.getsource(estimate_effect)
2688+
# Must have try/except or other handling for non-numeric T0/T1
2689+
assert "_parse_tv" in source or "except" in source
2690+
2691+
26562692
# ── MCP CLI ────────────────────────────────────────────────────────────
26572693

26582694

0 commit comments

Comments
 (0)