|
39 | 39 | instructions="""\ |
40 | 40 | Causal discovery & inference expert — turns any dataset into a causal graph, estimates effects, and performs causal reasoning. |
41 | 41 |
|
42 | | -## Tools (10) |
| 42 | +## Tools (12) |
43 | 43 | ### Core Pipeline |
44 | 44 | 1. **discover** — autonomous pipeline. Handles everything: data diagnosis, algorithm |
45 | 45 | selection, hyperparameter tuning, execution, postprocessing. Use for 90% of cases. |
|
62 | 62 | 10. **simulate_intervention** — interventional what-if: simulate shifting or setting a |
63 | 63 | treatment value and see the outcome distribution change. Uses DoWhy GCM. |
64 | 64 |
|
| 65 | +### Analysis & Validation |
| 66 | +11. **compute_feature_importance** — SHAP-based feature importance: which variables most |
| 67 | + influence a target? Uses linear or tree SHAP. |
| 68 | +12. **validate_graph** — graph falsification: test if the discovered causal graph is |
| 69 | + consistent with the data. Uses DoWhy GCM LMC testing. |
| 70 | +
|
65 | 71 | ## Workflow |
66 | 72 | - Full: discover(csv) → inspect_graph(run_id, T, Y) → estimate_effect(run_id, T, Y) |
67 | 73 | - Quick: discover(csv) → estimate_effect(run_id, T, Y) |
68 | 74 | - Validate: estimate_effect(…) → refute_estimate(run_id, T, Y) for robustness |
| 75 | +- Graph check: discover(csv) → validate_graph(run_id) to test graph-data consistency |
| 76 | +- Feature drivers: discover(csv) → compute_feature_importance(run_id, target) for SHAP |
69 | 77 | - What-if: discover(csv) → estimate_counterfactual(run_id, T, Y, value) |
70 | 78 | - Root cause: discover(csv) → attribute_anomaly(run_id, target_node) |
71 | 79 | - Expert: diagnose_data(csv) → run_algorithm(csv, algo) → estimate_effect(run_id, T, Y) |
@@ -1764,6 +1772,159 @@ def simulate_intervention( |
1764 | 1772 | return json.dumps(results, indent=2, cls=_NumpyEncoder) |
1765 | 1773 |
|
1766 | 1774 |
|
# ── Feature Importance ─────────────────────────────────────────────────


@mcp.tool()
def compute_feature_importance(
    target_node: str,
    run_id: str = "",
    csv_data: str = "",
    adjacency_matrix: str = "",
    node_names: str = "",
    data_diagnosis: str = "",
) -> str:
    """Compute SHAP-based feature importance for a target variable.

    Shows which variables have the strongest predictive influence on the
    target. Uses linear SHAP for linear data, tree SHAP for nonlinear.

    Args:
        target_node: Variable to analyze (must be in data columns)
        run_id: Run ID from discover/run_algorithm
        csv_data: CSV string (alternative to run_id)
        adjacency_matrix: JSON 2D array (needed with csv_data)
        node_names: JSON array of variable names (needed with csv_data)
        data_diagnosis: JSON with linearity info (optional)

    Returns:
        JSON with feature importance scores sorted by magnitude
    """
    df, adj, names, diagnosis = _resolve_data_and_graph(
        run_id, csv_data, adjacency_matrix, node_names,
        data_diagnosis=data_diagnosis,
    )

    if target_node not in df.columns:
        raise ToolError(f"Target '{target_node}' not in data columns.")

    # Default to linear SHAP; the diagnosis (when provided) can switch to tree SHAP.
    is_linear = True
    if diagnosis:
        is_linear = diagnosis.get("linearity", True)

    try:
        from causal_copilot.mcp.estimation import compute_feature_importance as _compute_fi

        with _pipeline_cwd():
            results = _compute_fi(df, target_node, is_linear)
    except Exception as e:
        return json.dumps({
            "status": "error",
            "error": f"Feature importance failed: {e}",
            "next_steps": ["Check data has enough observations and variance."],
        })

    results["status"] = "ok"

    # Interpretation. Use .get() so a malformed estimator payload (missing
    # "top_features"/"method") degrades gracefully instead of raising an
    # uncaught KeyError outside the try block above.
    top = results.get("top_features", [])[:3]
    method = results.get("method", "unknown")
    if top:
        results["interpretation"] = (
            f"Top drivers of {target_node}: {', '.join(top)}. "
            f"Method: {method}."
        )
    else:
        results["interpretation"] = (
            f"No dominant drivers identified for {target_node}. "
            f"Method: {method}."
        )

    # Build next_steps without the empty-string placeholder the previous
    # conditional expression inserted when no top feature existed.
    next_steps = []
    if top:
        next_steps.append(
            f"estimate_effect(treatment='{top[0]}', outcome='{target_node}') "
            "to quantify causal effect of the top feature"
        )
    next_steps.append(
        "inspect_graph() to see causal structure between these variables"
    )
    results["next_steps"] = next_steps

    return json.dumps(results, indent=2, cls=_NumpyEncoder)
| 1843 | + |
| 1844 | + |
# ── Graph Validation ─────────────────────────────────────────────────


@mcp.tool()
def validate_graph(
    run_id: str = "",
    csv_data: str = "",
    adjacency_matrix: str = "",
    node_names: str = "",
    n_permutations: int = 20,
) -> str:
    """Falsification test: is the causal graph consistent with the data?

    Runs DoWhy GCM Local Markov Condition (LMC) testing, comparing the
    proposed graph against randomly permuted graphs. A low p-value means
    the proposed structure beats random — i.e. it is informative.

    Args:
        run_id: Run ID from discover/run_algorithm
        csv_data: CSV string (alternative to run_id)
        adjacency_matrix: JSON 2D array (needed with csv_data)
        node_names: JSON array of variable names (needed with csv_data)
        n_permutations: Number of random graph permutations (default 20)

    Returns:
        JSON with falsification test results and interpretation
    """
    df, adj, names, _ = _resolve_data_and_graph(
        run_id, csv_data, adjacency_matrix, node_names,
    )

    # Falsification needs a DAG; sanitize non-DAG graphs (CPDAG/PAG edges)
    # before testing, tracking any edges that had to be dropped.
    graph_kind = classify_graph_kind(adj)
    if graph_kind == "dag":
        clean_adj, dropped = adj, []
    else:
        clean_adj, dropped = _sanitize_for_estimation(adj, names)

    try:
        from causal_copilot.mcp.estimation import run_graph_falsification

        with _pipeline_cwd():
            results = run_graph_falsification(df, clean_adj, names, n_permutations)
    except Exception as e:
        return json.dumps({
            "status": "error",
            "error": f"Graph falsification failed: {e}",
            "next_steps": ["Ensure graph is a DAG and data has enough observations."],
        })

    results["status"] = "ok"
    results["graph_kind"] = graph_kind
    if dropped:
        results["dropped_edges"] = dropped

    # Translate the p-value into a plain-language verdict.
    p = results.get("p_value")
    if p is None:
        interpretation = (
            "Falsification test completed. See falsification_result for details."
        )
    elif p < 0.05:
        interpretation = (
            f"Graph is significantly better than random (p={p:.4f}). "
            "The causal structure appears consistent with the data."
        )
    else:
        interpretation = (
            f"Graph is NOT significantly better than random (p={p:.4f}). "
            "The causal structure may not fit the data well."
        )
    results["interpretation"] = interpretation

    results["next_steps"] = [
        "discover() to re-run causal discovery if graph doesn't fit",
        "run_algorithm() with a different algorithm",
    ]

    return json.dumps(results, indent=2, cls=_NumpyEncoder)
| 1926 | + |
| 1927 | + |
1767 | 1928 | # ── MCP Resources ───────────────────────────────────────────────────── |
1768 | 1929 | from causal_copilot.mcp.resources import ( |
1769 | 1930 | get_algorithm_resources, |
|
0 commit comments