Skip to content

Commit eb7680e

Browse files
cauchyturingclaude authored and committed
feat(mcp): add feature_importance + graph_validation — 12 tools, true 100% coverage
Audit against Analysis.forward() revealed two missing capabilities: - Feature Importance (SHAP) — Analysis.feature_importance() dispatched by forward() - Graph Falsification — postprocess/judge.py via dowhy.gcm.falsify New tools: - compute_feature_importance: SHAP-based (linear/tree) feature importance - validate_graph: DoWhy GCM graph falsification (LMC testing) 62 tests, 12 tools. Every runnable capability in Causal-Copilot now MCP-native. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e9ab293 commit eb7680e

File tree

3 files changed

+398
-4
lines changed

3 files changed

+398
-4
lines changed

causal_copilot/mcp/estimation.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,109 @@ def run_distribution_change(
676676
}
677677

678678

679+
def compute_feature_importance(
    data: pd.DataFrame,
    target_node: str,
    is_linear: bool = True,
) -> dict:
    """Compute SHAP-based feature importance for a target variable.

    Uses linear-model SHAP for linear data, tree SHAP for nonlinear.

    Args:
        data: Observational data; must contain ``target_node`` as a column.
        target_node: Column whose drivers are being ranked.
        is_linear: If True, fit a LinearRegression and explain via a
            sampled-background explainer; otherwise fit a
            RandomForestRegressor and use TreeSHAP.

    Returns:
        Dict with the target name, the method used, a feature -> mean
        absolute SHAP value mapping sorted descending, and the top-10
        feature names.
    """
    import shap
    from sklearn.linear_model import LinearRegression
    from sklearn.ensemble import RandomForestRegressor

    X = data.drop(columns=[target_node])
    # Keep the target 1-D: a 2-D y (``data[[target_node]]``) makes
    # LinearRegression.predict return (n, 1) arrays, which produces 3-D
    # SHAP values and breaks the DataFrame construction below.
    y = data[target_node]

    if is_linear:
        model = LinearRegression()
        model.fit(X, y)
        # Background sample for the model-agnostic explainer; floor at one
        # row so very small datasets don't yield an empty background.
        n_background = max(1, min(int(len(X) * 0.2), 100))
        background = shap.utils.sample(X, n_background)
        explainer = shap.Explainer(model.predict, background)
        shap_values = explainer(X)
    else:
        model = RandomForestRegressor(n_estimators=100, random_state=42)
        model.fit(X, y)
        explainer = shap.TreeExplainer(model)
        shap_values = explainer(X)

    # Mean |SHAP| per feature, largest first.
    shap_df = pd.DataFrame(np.abs(shap_values.values), columns=X.columns)
    mean_shap = shap_df.mean().sort_values(ascending=False)

    return {
        "target_node": target_node,
        "method": "linear_shap" if is_linear else "tree_shap",
        "feature_importance": {
            col: _safe_float(val) for col, val in mean_shap.items()
        },
        "top_features": list(mean_shap.head(10).index),
    }
719+
720+
721+
def run_graph_falsification(
    data: pd.DataFrame,
    adj: np.ndarray,
    names: list[str],
    n_permutations: int = 20,
) -> dict:
    """Test if a causal graph is consistent with data via DoWhy GCM falsification.

    Checks Local Markov Condition (LMC) violations by comparing the proposed
    graph against random permutations.

    Args:
        data: Observational data; columns matching ``names`` are used.
        adj: Adjacency matrix where ``adj[i, j] == 1`` encodes the edge
            ``names[j] -> names[i]``.
        names: Variable names, aligned with the rows/columns of ``adj``.
        n_permutations: Number of random graph permutations for the test.

    Returns:
        Dict with the raw falsification summary string, the parsed p-value
        (or None if it could not be extracted), and graph size stats.
    """
    import re

    import networkx as nx
    from dowhy.gcm.falsify import falsify_graph

    G = nx.DiGraph()
    G.add_nodes_from(names)
    n = adj.shape[0]
    for i in range(n):
        for j in range(n):
            if adj[i, j] == 1:
                G.add_edge(names[j], names[i])

    # falsify_graph requires a DAG: break each remaining cycle by removing
    # its closing edge until the graph is acyclic.
    while not nx.is_directed_acyclic_graph(G):
        try:
            cycle = list(next(iter(nx.simple_cycles(G))))
            G.remove_edge(cycle[-1], cycle[0])
        except StopIteration:
            break

    # Restrict to columns that actually exist in the data, in graph order.
    df = data[[c for c in names if c in data.columns]].copy()

    result = falsify_graph(
        G, df,
        n_permutations=n_permutations,
        plot_histogram=False,
        suggestions=True,
    )

    # falsify_graph returns an object whose __str__ summarizes the test;
    # scrape the p-value out of that summary for structured output.
    result_str = str(result)
    p_value = None
    p_match = re.search(r'p_value\s*=?\s*([\d.]+)', result_str)
    if p_match:
        p_value = float(p_match.group(1))

    return {
        "falsification_result": result_str,
        # Check ``is not None`` (not truthiness): p_value == 0.0 is a valid,
        # maximally significant result and must not be collapsed to None.
        "p_value": _safe_float(p_value) if p_value is not None else None,
        "n_permutations": n_permutations,
        "n_nodes": len(G.nodes),
        "n_edges": len(G.edges),
    }
780+
781+
679782
def run_intervention_simulation(
680783
data: pd.DataFrame,
681784
adj: np.ndarray,

causal_copilot/mcp/server.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
instructions="""\
4040
Causal discovery & inference expert — turns any dataset into a causal graph, estimates effects, and performs causal reasoning.
4141
42-
## Tools (10)
42+
## Tools (12)
4343
### Core Pipeline
4444
1. **discover** — autonomous pipeline. Handles everything: data diagnosis, algorithm
4545
selection, hyperparameter tuning, execution, postprocessing. Use for 90% of cases.
@@ -62,10 +62,18 @@
6262
10. **simulate_intervention** — interventional what-if: simulate shifting or setting a
6363
treatment value and see the outcome distribution change. Uses DoWhy GCM.
6464
65+
### Analysis & Validation
66+
11. **compute_feature_importance** — SHAP-based feature importance: which variables most
67+
influence a target? Uses linear or tree SHAP.
68+
12. **validate_graph** — graph falsification: test if the discovered causal graph is
69+
consistent with the data. Uses DoWhy GCM LMC testing.
70+
6571
## Workflow
6672
- Full: discover(csv) → inspect_graph(run_id, T, Y) → estimate_effect(run_id, T, Y)
6773
- Quick: discover(csv) → estimate_effect(run_id, T, Y)
6874
- Validate: estimate_effect(…) → refute_estimate(run_id, T, Y) for robustness
75+
- Graph check: discover(csv) → validate_graph(run_id) to test graph-data consistency
76+
- Feature drivers: discover(csv) → compute_feature_importance(run_id, target) for SHAP
6977
- What-if: discover(csv) → estimate_counterfactual(run_id, T, Y, value)
7078
- Root cause: discover(csv) → attribute_anomaly(run_id, target_node)
7179
- Expert: diagnose_data(csv) → run_algorithm(csv, algo) → estimate_effect(run_id, T, Y)
@@ -1764,6 +1772,159 @@ def simulate_intervention(
17641772
return json.dumps(results, indent=2, cls=_NumpyEncoder)
17651773

17661774

1775+
# ── Feature Importance ─────────────────────────────────────────────────
1776+
1777+
1778+
@mcp.tool()
def compute_feature_importance(
    target_node: str,
    run_id: str = "",
    csv_data: str = "",
    adjacency_matrix: str = "",
    node_names: str = "",
    data_diagnosis: str = "",
) -> str:
    """Compute SHAP-based feature importance for a target variable.

    Shows which variables have the strongest predictive influence on the
    target. Uses linear SHAP for linear data, tree SHAP for nonlinear.

    Args:
        target_node: Variable to analyze (must be in data columns)
        run_id: Run ID from discover/run_algorithm
        csv_data: CSV string (alternative to run_id)
        adjacency_matrix: JSON 2D array (needed with csv_data)
        node_names: JSON array of variable names (needed with csv_data)
        data_diagnosis: JSON with linearity info (optional)

    Returns:
        JSON with feature importance scores sorted by magnitude
    """
    df, adj, names, diagnosis = _resolve_data_and_graph(
        run_id, csv_data, adjacency_matrix, node_names,
        data_diagnosis=data_diagnosis,
    )

    if target_node not in df.columns:
        raise ToolError(f"Target '{target_node}' not in data columns.")

    # Default to the linear explainer unless the diagnosis says otherwise.
    is_linear = True
    if diagnosis:
        is_linear = diagnosis.get("linearity", True)

    try:
        from causal_copilot.mcp.estimation import compute_feature_importance as _compute_fi

        with _pipeline_cwd():
            results = _compute_fi(df, target_node, is_linear)
    except Exception as e:
        return json.dumps({
            "status": "error",
            "error": f"Feature importance failed: {e}",
            "next_steps": ["Check data has enough observations and variance."],
        })

    results["status"] = "ok"

    # Interpretation
    top = results["top_features"][:3]
    results["interpretation"] = (
        f"Top drivers of {target_node}: {', '.join(top)}. "
        f"Method: {results['method']}."
    )

    # Build next_steps conditionally; the old conditional-expression form
    # leaked an empty-string entry into the list when no features existed.
    next_steps = []
    if top:
        next_steps.append(
            f"estimate_effect(treatment='{top[0]}', outcome='{target_node}') "
            "to quantify causal effect of the top feature"
        )
    next_steps.append(
        "inspect_graph() to see causal structure between these variables"
    )
    results["next_steps"] = next_steps

    return json.dumps(results, indent=2, cls=_NumpyEncoder)
1843+
1844+
1845+
# ── Graph Validation ─────────────────────────────────────────────────
1846+
1847+
1848+
@mcp.tool()
def validate_graph(
    run_id: str = "",
    csv_data: str = "",
    adjacency_matrix: str = "",
    node_names: str = "",
    n_permutations: int = 20,
) -> str:
    """Test if a causal graph is consistent with data (falsification).

    Uses DoWhy GCM Local Markov Condition (LMC) testing. Compares the
    proposed graph against random permutations. Low p-value means the
    graph is significantly better than random.

    Args:
        run_id: Run ID from discover/run_algorithm
        csv_data: CSV string (alternative to run_id)
        adjacency_matrix: JSON 2D array (needed with csv_data)
        node_names: JSON array of variable names (needed with csv_data)
        n_permutations: Number of random graph permutations (default 20)

    Returns:
        JSON with falsification test results and interpretation
    """
    df, adj, names, _ = _resolve_data_and_graph(
        run_id, csv_data, adjacency_matrix, node_names,
    )

    # Non-DAG outputs (e.g. CPDAG/PAG edge marks) are sanitized to a DAG
    # before falsification; record any edges that were dropped.
    graph_kind = classify_graph_kind(adj)
    clean_adj, dropped = adj, []
    if graph_kind != "dag":
        clean_adj, dropped = _sanitize_for_estimation(adj, names)

    try:
        from causal_copilot.mcp.estimation import run_graph_falsification

        with _pipeline_cwd():
            results = run_graph_falsification(df, clean_adj, names, n_permutations)
    except Exception as e:
        return json.dumps({
            "status": "error",
            "error": f"Graph falsification failed: {e}",
            "next_steps": ["Ensure graph is a DAG and data has enough observations."],
        })

    results["status"] = "ok"
    results["graph_kind"] = graph_kind
    if dropped:
        results["dropped_edges"] = dropped

    # Interpret the parsed p-value (may be absent if parsing failed).
    p = results.get("p_value")
    if p is None:
        results["interpretation"] = (
            "Falsification test completed. See falsification_result for details."
        )
    elif p < 0.05:
        results["interpretation"] = (
            f"Graph is significantly better than random (p={p:.4f}). "
            "The causal structure appears consistent with the data."
        )
    else:
        results["interpretation"] = (
            f"Graph is NOT significantly better than random (p={p:.4f}). "
            "The causal structure may not fit the data well."
        )

    results["next_steps"] = [
        "discover() to re-run causal discovery if graph doesn't fit",
        "run_algorithm() with a different algorithm",
    ]

    return json.dumps(results, indent=2, cls=_NumpyEncoder)
1926+
1927+
17671928
# ── MCP Resources ─────────────────────────────────────────────────────
17681929
from causal_copilot.mcp.resources import (
17691930
get_algorithm_resources,

0 commit comments

Comments
 (0)