Skip to content

Commit 0283d8d

Browse files
cauchyturingclaude
andcommitted
feat: achieve 1:1 scientific rigor parity with main.py across all layers
Inference (MCP + Python API): - Replace weak MCP confounder detection (adj==1 only) with offline.identify_confounders (1/3/4 confirmed, 2 potential) - Replace naive MCP auto-select (binary→matching) with offline.select_estimation_method (full decision tree) - Pass is_linear/treatment_kind to DML/DRL for proper variant selection (LinearDML/SparseLinearDML/CausalForestDML) - Select MetaLearner variant based on linearity (TLearner for linear, XLearner for nonlinear) - Add SHAP-benchmarked partial-R2 sensitivity analysis to MCP refute_estimate - Add treatment_kind and algo to MCP result provenance Discovery (Python API): - Replace dropna() with stat_info_collection() (MICE imputation, label encoding, z-score normalization) - Replace heuristic stat tests with Ramsey RESET (linearity) and Shapiro-Wilk (gaussianity) - Add LLM Filter+Reranker algorithm selection with rule-based fallback - Add LLM HyperparameterSelector with defaults fallback - Add Judge postprocessing (bootstrap stability + KCI pruning + LLM refinement) - Store original-scale data for estimation (not normalized), matching main.py Analysis class Discovery (MCP): - Remove TS guard on Judge postprocessing — now runs for ALL data including time-series All changes fall back gracefully if pipeline modules unavailable. 422 passed, 10 skipped, 0 failed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5af2d98 commit 0283d8d

File tree

7 files changed

+1993
-112
lines changed

7 files changed

+1993
-112
lines changed

causal_copilot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from causal_copilot.copilot import CausalCopilot as CausalCopilot
66
from causal_copilot.core.result import CausalResult as CausalResult
77
from causal_copilot.core.result import Provenance as Provenance
8+
from causal_copilot.core.result import TreatmentEffect as TreatmentEffect

causal_copilot/cli.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def cmd_version(args):
9292

9393

9494
def cmd_analyze(args):
95-
"""Run causal discovery on a CSV file."""
95+
"""Run causal discovery on a CSV file, optionally followed by effect estimation."""
9696
from causal_copilot import CausalCopilot
9797

9898
try:
@@ -108,12 +108,30 @@ def cmd_analyze(args):
108108
seed=args.seed,
109109
)
110110

111+
# Run estimation if treatment/outcome specified
112+
if result.status == "ok" and args.treatment and args.outcome:
113+
try:
114+
result = copilot.estimate_effect(
115+
result,
116+
treatment=args.treatment,
117+
outcome=args.outcome,
118+
method=args.method,
119+
)
120+
except (ValueError, ImportError) as e:
121+
print(f"Estimation error: {e}", file=sys.stderr)
122+
111123
if args.output:
112124
out_path = Path(args.output)
113125
out_path.write_text(json.dumps(result.to_dict(), indent=2))
114126
print(f"Result written to {out_path}")
115127
else:
116128
print(result.summary)
129+
if result.effects:
130+
print("\nEffects:")
131+
for key, eff in result.effects.items():
132+
print(f" {key}: ATE={eff.ate}, method={eff.method}")
133+
if eff.ate_ci:
134+
print(f" 95% CI: [{eff.ate_ci[0]:.4f}, {eff.ate_ci[1]:.4f}]")
117135
if result.warnings:
118136
print(f"\nWarnings ({len(result.warnings)}):")
119137
for w in result.warnings:
@@ -307,6 +325,11 @@ def main(argv=None):
307325
p_analyze.add_argument("--planner", "-p", default="rule", help="Planner: rule (default)")
308326
p_analyze.add_argument("--timeout", "-t", type=int, default=300, help="Timeout in seconds (default: 300)")
309327
p_analyze.add_argument("--seed", "-s", type=int, default=42, help="Random seed (default: 42)")
328+
p_analyze.add_argument("--treatment", "-T", help="Treatment variable (enables effect estimation)")
329+
p_analyze.add_argument("--outcome", "-Y", help="Outcome variable (enables effect estimation)")
330+
p_analyze.add_argument(
331+
"--method", "-m", help="Estimation method: linear, matching, dml, drl, metalearner, iv (auto if omitted)"
332+
)
310333

311334
# benchmark
312335
p_bench = sub.add_parser("benchmark", help="Run benchmark evaluation")

0 commit comments

Comments
 (0)