Skip to content

Commit cf10aa3

Browse files
cauchyturingclaude
and committed
fix: audit round 4 — 8 bugs for main.py parity + scientific rigor
B1: T0/T1 defaults changed from float 0.0/1.0 to str "" (auto-detect: binary/discrete→min/max, continuous→10th/90th percentile) B2: EDA fallback dict now complete — all 8 keys accessed by report_generation.py (dist_analysis_num/cat, corr_analysis, plot_path_lag_corr, lag_corr_summary as dict, diagnostics_summary) B3: discover() resolver overrides now conditional ("indep_test" in algo_args) matching hyperparameter_selector.py:38 B4: domain_index/heterogeneous detection in make_global_state (prevents domain_index as causal variable, enables CDNOD) B5: DRL always discretizes non-binary treatment (matching inference.py:901) B6: diagnose_data + run_algorithm call convert_stat_info_to_text B7: T→O edge restored after sanitization for CPDAG estimation B8: copilot.py extracts lagged_graph from metadata (fixes dead TS Judge skip) Also: refute_estimate T0/T1 auto-detection, test mock updated for convert_stat_info_to_text. 131 tests pass (10 new). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent efd1726 commit cf10aa3

File tree

5 files changed

+314
-55
lines changed

5 files changed

+314
-55
lines changed

causal_copilot/copilot.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,20 @@ def analyze(
509509
gs.user_data.processed_data = numeric_df
510510
gs.user_data.selected_features = list(numeric_df.columns)
511511

512+
# Extract lagged_graph from metadata for time-series algos.
513+
# Programming.forward() (program.py:58-78) does this from
514+
# info['lag_matrix']. Without it, TS Judge skip is dead code.
515+
is_ts = getattr(gs.statistics, "time_series", False)
516+
if is_ts and isinstance(metadata, dict) and "lag_matrix" in metadata:
517+
lag = metadata["lag_matrix"]
518+
if isinstance(lag, list):
519+
lag = np.array(lag)
520+
gs.results.lagged_graph = lag
521+
elif is_ts:
522+
gs.results.lagged_graph = None
523+
512524
# Postprocess: skip Judge for time-series data when lagged_graph
513525
# exists (main.py:268: time_series AND lagged_graph is not None)
514-
is_ts = getattr(gs.statistics, "time_series", False)
515526
has_lagged = getattr(gs.results, "lagged_graph", None) is not None
516527
if is_ts and has_lagged:
517528
gs.results.revised_graph = gs.results.converted_graph
@@ -1450,15 +1461,21 @@ def generate_report(
14501461
eda.generate_eda()
14511462
except Exception as eda_err:
14521463
report_warnings.append(f"EDA generation skipped: {eda_err}")
1453-
# Set minimal eda with required keys to prevent KeyError in
1454-
# report_generation.py:386 eda_prompt() accessing plot_path_dist/corr
1455-
# and ts_eda_prompt():339-356 accessing lag_corr_summary/diagnostics_summary
1464+
# Set minimal EDA with ALL keys accessed by report_generation.py:
1465+
# Non-TS: eda_summary_to_latex() → dist_analysis_num, dist_analysis_cat,
1466+
# corr_analysis; eda_prompt() → plot_path_dist, plot_path_corr
1467+
# TS: ts_eda_prompt() → lag_corr_summary (dict w/ potential_granger_causality),
1468+
# plot_path_lag_corr, diagnostics_summary (dict)
14561469
if not hasattr(gs.results, "eda") or gs.results.eda is None or not gs.results.eda:
14571470
gs.results.eda = {
14581471
"plot_path_dist": [""],
14591472
"plot_path_corr": [""],
1460-
"lag_corr_summary": "",
1461-
"diagnostics_summary": "",
1473+
"plot_path_lag_corr": "",
1474+
"dist_analysis_num": {},
1475+
"dist_analysis_cat": {},
1476+
"corr_analysis": {},
1477+
"lag_corr_summary": {"potential_granger_causality": []},
1478+
"diagnostics_summary": {},
14621479
}
14631480

14641481
# 2. Visualizations

causal_copilot/mcp/bridge.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ def make_global_state(df, query="", algorithm=None, seed=42):
6464
gs.user_data.output_report_dir = output_dir
6565
gs.user_data.output_graph_dir = output_dir
6666

67+
# domain_index detection — matching Initialize_state.py:44-49.
68+
# Without this, domain_index column is treated as a causal variable
69+
# and CDNOD (heterogeneous data algo) won't be triggered.
70+
if "domain_index" in df.columns:
71+
if df["domain_index"].nunique() > 1:
72+
gs.statistics.heterogeneous = True
73+
else:
74+
gs.statistics.heterogeneous = False
75+
gs.statistics.domain_index = "domain_index"
76+
6777
if algorithm:
6878
gs.algorithm.selected_algorithm = algorithm
6979

causal_copilot/mcp/estimation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,17 @@ def estimate_drl(
261261
df["_W_dummy"] = 0.0
262262
actual_W = ["_W_dummy"]
263263

264-
# DRL requires discrete treatment — discretize continuous treatment into
265-
# quantile bins, matching inference.py prepare_treatment_column(discretize=True).
266-
# The Analysis class always sets discretize=True for all DRL variants (line 901).
264+
# DRL requires discrete treatment — always discretize, matching
265+
# inference.py line 901 which sets discretize=True for ALL DRL variants.
266+
# pd.qcut on binary data will raise ValueError and fall back gracefully.
267267
T_series = df[treatment]
268-
if treatment_kind == "continuous" or (pd.api.types.is_numeric_dtype(T_series) and T_series.nunique() > 10):
268+
if treatment_kind != "binary":
269269
try:
270270
df[treatment] = pd.qcut(T_series, q=3, labels=[0, 1, 2])
271271
unique_vals = sorted(df[treatment].unique())
272272
T0, T1 = unique_vals[0], unique_vals[-1]
273273
except ValueError:
274-
pass # qcut fails on low-variance data; use raw treatment
274+
pass # qcut fails on low-variance/binary data; use raw treatment
275275

276276
Y = df[outcome].values
277277
T = df[treatment].values

causal_copilot/mcp/server.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,8 @@ def estimate_effect(
415415
adjacency_matrix: str = "",
416416
node_names: str = "",
417417
method: str = "",
418-
control_value: float = 0.0,
419-
treatment_value: float = 1.0,
418+
control_value: str = "",
419+
treatment_value: str = "",
420420
confounders: str = "",
421421
data_diagnosis: str = "",
422422
instrument: str = "",
@@ -439,8 +439,10 @@ def estimate_effect(
439439
node_names: JSON array of variable names
440440
method: Estimation method ("linear", "matching", "dml", "drl",
441441
"metalearner", "iv", or "" for auto)
442-
control_value: Reference value for control group (default 0.0)
443-
treatment_value: Reference value for treatment group (default 1.0)
442+
control_value: Control group reference value (empty for auto-detect:
443+
binary/discrete → min, continuous → 10th percentile)
444+
treatment_value: Treatment group reference value (empty for auto-detect:
445+
binary/discrete → max, continuous → 90th percentile)
444446
confounders: JSON array of confounder names (default: auto-detect from graph)
445447
data_diagnosis: JSON with linearity/gaussian_error (needed for CPDAG)
446448
instrument: Instrument variable name for IV method (auto-detected from graph if empty)
@@ -486,10 +488,11 @@ def estimate_effect(
486488
select_estimation_method as _offline_select_method,
487489
)
488490

489-
# Compute T0/T1 from data — for continuous treatment, prepare_treatment
490-
# returns 10th/90th percentile values. Feed these back into control_value/
491-
# treatment_value for all estimation calls (matching copilot.py:689-696).
492-
_, T0_computed, T1_computed, treatment_kind = prepare_treatment(df, treatment, T0=control_value, T1=treatment_value)
491+
# Parse user-provided T0/T1 or leave as None for auto-detection.
492+
# Auto-detect: binary/discrete → min/max, continuous → 10th/90th percentile.
493+
T0_input = float(control_value) if control_value else None
494+
T1_input = float(treatment_value) if treatment_value else None
495+
_, T0_computed, T1_computed, treatment_kind = prepare_treatment(df, treatment, T0=T0_input, T1=T1_input)
493496
control_value = T0_computed
494497
treatment_value = T1_computed
495498

@@ -595,9 +598,15 @@ def estimate_effect(
595598
+ ("..." if len(dropped_edges) > 5 else "")
596599
)
597600

598-
# --- Confounders ---
601+
# --- Restore T→O edge if it was undirected (CPDAG) and got dropped ---
599602
t_idx = names.index(treatment)
600603
o_idx = names.index(outcome)
604+
had_edge = adj[o_idx, t_idx] != 0 or adj[t_idx, o_idx] != 0
605+
if had_edge and clean_adj[o_idx, t_idx] == 0:
606+
clean_adj[o_idx, t_idx] = 1 # Restore as directed T→O
607+
warnings_list.append(f"Restored {treatment}->{outcome} as directed for estimation (was undirected in CPDAG)")
608+
609+
# --- Confounders ---
601610
if confounders:
602611
try:
603612
conf_list = json.loads(confounders)
@@ -878,11 +887,15 @@ def diagnose_data(csv_data: str) -> str:
878887
raise ToolError("Need at least 2 columns of data.")
879888

880889
try:
881-
from preprocess.stat_info_functions import stat_info_collection
890+
from preprocess.stat_info_functions import (
891+
convert_stat_info_to_text,
892+
stat_info_collection,
893+
)
882894

883895
gs = make_global_state(df)
884896
with _pipeline_cwd():
885897
gs = stat_info_collection(gs)
898+
gs.statistics.description = convert_stat_info_to_text(gs.statistics)
886899

887900
stats = gs.statistics
888901

@@ -1052,13 +1065,17 @@ def run_algorithm(
10521065
from causal_discovery.ci_test_resolver import resolve_ci_test
10531066
from causal_discovery.program import Programming
10541067
from causal_discovery.score_resolver import resolve_score_func
1055-
from preprocess.stat_info_functions import stat_info_collection
1068+
from preprocess.stat_info_functions import (
1069+
convert_stat_info_to_text,
1070+
stat_info_collection,
1071+
)
10561072

10571073
gs = make_global_state(df, algorithm=algorithm, seed=seed)
10581074
args = make_args(seed=seed)
10591075

10601076
with _pipeline_cwd():
10611077
gs = stat_info_collection(gs)
1078+
gs.statistics.description = convert_stat_info_to_text(gs.statistics)
10621079

10631080
# Start from user's exact hyperparameters
10641081
requested_hp = dict(hp)
@@ -1323,11 +1340,11 @@ def discover(
13231340
"IAMBnPC",
13241341
"MBOR",
13251342
}
1326-
if algo_name in ci_test_algos:
1343+
if algo_name in ci_test_algos and "indep_test" in algo_args:
13271344
algo_args["indep_test"] = resolve_ci_test(gs.statistics)
13281345

13291346
score_algos = {"GES", "FGES", "XGES", "GRaSP", "ExactSearch", "BOSS"}
1330-
if algo_name in score_algos:
1347+
if algo_name in score_algos and "score_func" in algo_args:
13311348
algo_args["score_func"] = resolve_score_func(
13321349
gs.statistics,
13331350
algo_name,
@@ -1701,8 +1718,8 @@ def refute_estimate(
17011718
csv_data: str = "",
17021719
adjacency_matrix: str = "",
17031720
node_names: str = "",
1704-
control_value: float = 0.0,
1705-
treatment_value: float = 1.0,
1721+
control_value: str = "",
1722+
treatment_value: str = "",
17061723
) -> str:
17071724
"""Test robustness of a causal effect estimate with sensitivity analysis.
17081725
@@ -1722,8 +1739,8 @@ def refute_estimate(
17221739
csv_data: CSV string (alternative to run_id)
17231740
adjacency_matrix: JSON 2D array (needed with csv_data)
17241741
node_names: JSON array of variable names (needed with csv_data)
1725-
control_value: Control group value (default 0.0)
1726-
treatment_value: Treatment group value (default 1.0)
1742+
control_value: Control group reference value (empty for auto-detect)
1743+
treatment_value: Treatment group reference value (empty for auto-detect)
17271744
17281745
Returns:
17291746
JSON with original_estimate, refutation results, interpretation
@@ -1751,6 +1768,15 @@ def refute_estimate(
17511768
}
17521769
)
17531770

1771+
# Parse T0/T1 with auto-detection
1772+
from causal_copilot.mcp.offline import prepare_treatment
1773+
1774+
T0_input = float(control_value) if control_value else None
1775+
T1_input = float(treatment_value) if treatment_value else None
1776+
_, T0_val, T1_val, _ = prepare_treatment(df, treatment, T0=T0_input, T1=T1_input)
1777+
control_value = T0_val
1778+
treatment_value = T1_val
1779+
17541780
clean_adj, _ = _sanitize_for_estimation(adj, names)
17551781
dot_graph = _adj_to_dot(clean_adj, names)
17561782

@@ -2501,15 +2527,21 @@ def generate_report(run_id: str) -> str:
25012527
eda.generate_eda()
25022528
except Exception as eda_err:
25032529
report_warnings.append(f"EDA generation skipped: {eda_err}")
2504-
# Set minimal eda with required keys to prevent KeyError in
2505-
# report_generation.py:386 eda_prompt() accessing plot_path_dist/corr
2506-
# and ts_eda_prompt():339-356 accessing lag_corr_summary/diagnostics_summary
2530+
# Set minimal EDA with ALL keys accessed by report_generation.py:
2531+
# Non-TS: eda_summary_to_latex() → dist_analysis_num, dist_analysis_cat,
2532+
# corr_analysis; eda_prompt() → plot_path_dist, plot_path_corr
2533+
# TS: ts_eda_prompt() → lag_corr_summary (dict w/ potential_granger_causality),
2534+
# plot_path_lag_corr, diagnostics_summary (dict)
25072535
if not hasattr(gs.results, "eda") or gs.results.eda is None or not gs.results.eda:
25082536
gs.results.eda = {
25092537
"plot_path_dist": [""],
25102538
"plot_path_corr": [""],
2511-
"lag_corr_summary": "",
2512-
"diagnostics_summary": "",
2539+
"plot_path_lag_corr": "",
2540+
"dist_analysis_num": {},
2541+
"dist_analysis_cat": {},
2542+
"corr_analysis": {},
2543+
"lag_corr_summary": {"potential_granger_causality": []},
2544+
"diagnostics_summary": {},
25132545
}
25142546

25152547
# 2. Visualizations — graph plots, heatmaps

0 commit comments

Comments (0)