@@ -660,7 +660,10 @@ def estimate_effect(
660660
661661 # --- Treatment type dispatch (4 cases, matching original) ---
662662 _ , T0 , T1 , treatment_kind = prepare_treatment (
663- df , treatment , T0 = control_value , T1 = treatment_value ,
663+ df ,
664+ treatment ,
665+ T0 = control_value ,
666+ T1 = treatment_value ,
664667 )
665668 control_value = T0
666669 treatment_value = T1
@@ -674,9 +677,7 @@ def estimate_effect(
674677 is_lg = is_linear and is_gaussian
675678 policy = check_inference_policy (adj , is_linear_gaussian = is_lg )
676679 if not policy ["allow_inference" ]:
677- result .warnings = warnings_list + [
678- f"Inference rejected: { policy ['reason' ]} "
679- ]
680+ result .warnings = warnings_list + [f"Inference rejected: { policy ['reason' ]} " ]
680681 result .summary += f" Inference rejected: { policy ['reason' ]} "
681682 return result
682683 warnings_list .append ("CPDAG: undirected edges dropped for estimation" )
@@ -701,8 +702,7 @@ def estimate_effect(
701702 if had_edge and clean_adj [o_idx , t_idx ] == 0 :
702703 clean_adj [o_idx , t_idx ] = 1
703704 warnings_list .append (
704- f"Restored { treatment } ->{ outcome } as directed for estimation "
705- f"(was undirected in CPDAG)"
705+ f"Restored { treatment } ->{ outcome } as directed for estimation (was undirected in CPDAG)"
706706 )
707707
708708 # --- Determine confounders (full adj matrix type check) ---
@@ -711,14 +711,15 @@ def estimate_effect(
711711 potential_conf = []
712712 else :
713713 conf_list , potential_conf = identify_confounders (
714- adj , names , treatment , outcome ,
714+ adj ,
715+ names ,
716+ treatment ,
717+ outcome ,
715718 )
716719 if potential_conf and not conf_list :
717720 # Use potential confounders when no confirmed ones exist
718721 conf_list = potential_conf
719- warnings_list .append (
720- f"Using potential confounders from undirected edges: { conf_list } "
721- )
722+ warnings_list .append (f"Using potential confounders from undirected edges: { conf_list } " )
722723
723724 # --- Check for IV ---
724725 has_iv = bool (instrument )
@@ -730,9 +731,7 @@ def estimate_effect(
730731 continue
731732 if clean_adj [o_idx , z_idx ] == 1 :
732733 continue
733- has_parents = any (
734- clean_adj [z_idx , j ] == 1 for j in range (n ) if j != z_idx
735- )
734+ has_parents = any (clean_adj [z_idx , j ] == 1 for j in range (n ) if j != z_idx )
736735 if not has_parents :
737736 has_iv = True
738737 if not instrument :
@@ -742,7 +741,9 @@ def estimate_effect(
742741 # --- Auto-select method (data-driven, replaces LLM Filter) ---
743742 if method is None :
744743 selected_method = select_estimation_method (
745- df , treatment , treatment_kind ,
744+ df ,
745+ treatment ,
746+ treatment_kind ,
746747 is_linear = is_linear ,
747748 is_gaussian = is_gaussian ,
748749 n_features = len (names ) - 1 ,
@@ -765,40 +766,52 @@ def estimate_effect(
765766 if selected_method == "linear" :
766767 dot_graph = _adj_to_dot (clean_adj , names )
767768 estimates = estimate_linear (
768- df , dot_graph , treatment , outcome ,
769- control_value , treatment_value ,
769+ df ,
770+ dot_graph ,
771+ treatment ,
772+ outcome ,
773+ control_value ,
774+ treatment_value ,
770775 )
771776 elif selected_method == "matching" :
772- match_conf = conf_list if conf_list else [
773- c for c in names if c != treatment and c != outcome
774- ]
777+ match_conf = conf_list if conf_list else [c for c in names if c != treatment and c != outcome ]
775778 estimates = estimate_matching (
776- df , treatment , outcome , match_conf ,
777- int (control_value ), int (treatment_value ),
779+ df ,
780+ treatment ,
781+ outcome ,
782+ match_conf ,
783+ int (control_value ),
784+ int (treatment_value ),
778785 )
779786 elif selected_method == "dml" :
780787 X_col = [c for c in names if c != treatment and c != outcome and c not in conf_list ]
781788 if not X_col :
782- X_col = conf_list [:] if conf_list else [
783- c for c in names if c != treatment and c != outcome
784- ]
789+ X_col = conf_list [:] if conf_list else [c for c in names if c != treatment and c != outcome ]
785790 W_col = conf_list if conf_list else []
786791 estimates = estimate_dml (
787- df , treatment , outcome , X_col , W_col ,
788- control_value , treatment_value ,
792+ df ,
793+ treatment ,
794+ outcome ,
795+ X_col ,
796+ W_col ,
797+ control_value ,
798+ treatment_value ,
789799 is_linear = is_linear ,
790800 treatment_kind = treatment_kind ,
791801 )
792802 elif selected_method == "drl" :
793803 X_col = [c for c in names if c != treatment and c != outcome and c not in conf_list ]
794804 if not X_col :
795- X_col = conf_list [:] if conf_list else [
796- c for c in names if c != treatment and c != outcome
797- ]
805+ X_col = conf_list [:] if conf_list else [c for c in names if c != treatment and c != outcome ]
798806 W_col = conf_list if conf_list else []
799807 estimates = estimate_drl (
800- df , treatment , outcome , X_col , W_col ,
801- control_value , treatment_value ,
808+ df ,
809+ treatment ,
810+ outcome ,
811+ X_col ,
812+ W_col ,
813+ control_value ,
814+ treatment_value ,
802815 is_linear = is_linear ,
803816 treatment_kind = treatment_kind ,
804817 )
@@ -807,31 +820,37 @@ def estimate_effect(
807820 # Pick learner variant based on data (matches original)
808821 learner = "t" if is_linear else "x"
809822 estimates = estimate_metalearner (
810- df , treatment , outcome , X_col ,
811- control_value , treatment_value ,
823+ df ,
824+ treatment ,
825+ outcome ,
826+ X_col ,
827+ control_value ,
828+ treatment_value ,
812829 learner = learner ,
813830 )
814831 elif selected_method == "iv" :
815832 if not instrument :
816833 raise ValueError (
817- "No valid instrument found in graph. "
818- "Provide instrument= or use a different method."
834+ "No valid instrument found in graph. Provide instrument= or use a different method."
819835 )
820836 X_col = [c for c in names if c not in (treatment , outcome , instrument )]
821837 W_col = conf_list if conf_list else []
822838 estimates = estimate_iv (
823- df , treatment , outcome , instrument , X_col , W_col ,
824- control_value , treatment_value ,
839+ df ,
840+ treatment ,
841+ outcome ,
842+ instrument ,
843+ X_col ,
844+ W_col ,
845+ control_value ,
846+ treatment_value ,
825847 )
826848 else :
827849 raise ValueError (
828- f"Unknown method '{ selected_method } '. "
829- "Use: linear, matching, dml, drl, metalearner, iv."
850+ f"Unknown method '{ selected_method } '. Use: linear, matching, dml, drl, metalearner, iv."
830851 )
831852 except ImportError as e :
832- raise ImportError (
833- "Estimation requires inference extras: pip install causal-copilot[inference]"
834- ) from e
853+ raise ImportError ("Estimation requires inference extras: pip install causal-copilot[inference]" ) from e
835854
836855 # --- Build TreatmentEffect ---
837856 ate_info = estimates .get ("ate" , {})
@@ -961,8 +980,12 @@ def refute_estimate(
961980 from causal_copilot .mcp .estimation import run_refutation
962981
963982 return run_refutation (
964- df , dot_graph , treatment , outcome ,
965- control_value , treatment_value ,
983+ df ,
984+ dot_graph ,
985+ treatment ,
986+ outcome ,
987+ control_value ,
988+ treatment_value ,
966989 confounders = conf_list ,
967990 shap_top_feature = shap_top ,
968991 )
@@ -1003,11 +1026,7 @@ def _to_dag(self, adj: np.ndarray) -> np.ndarray:
10031026 # Orient undirected edges: lower column index → higher column index
10041027 for i in range (n ):
10051028 for j in range (i + 1 , n ):
1006- has_undirected = (
1007- (adj [i , j ] == 2 or adj [j , i ] == 2 )
1008- and dag [i , j ] == 0
1009- and dag [j , i ] == 0
1010- )
1029+ has_undirected = (adj [i , j ] == 2 or adj [j , i ] == 2 ) and dag [i , j ] == 0 and dag [j , i ] == 0
10111030 if has_undirected :
10121031 dag [j , i ] = 1 # i→j (adj[j,i]=1 means i causes j)
10131032
@@ -1045,14 +1064,8 @@ def inspect_graph(
10451064
10461065 # Edge statistics
10471066 n_directed = int (np .sum (adj == 1 ))
1048- n_undirected = sum (
1049- 1 for i in range (n ) for j in range (i + 1 , n )
1050- if adj [i , j ] == 2 or adj [j , i ] == 2
1051- )
1052- n_bidirected = sum (
1053- 1 for i in range (n ) for j in range (i + 1 , n )
1054- if adj [i , j ] == 3 or adj [j , i ] == 3
1055- )
1067+ n_undirected = sum (1 for i in range (n ) for j in range (i + 1 , n ) if adj [i , j ] == 2 or adj [j , i ] == 2 )
1068+ n_bidirected = sum (1 for i in range (n ) for j in range (i + 1 , n ) if adj [i , j ] == 3 or adj [j , i ] == 3 )
10561069
10571070 # Inference policy
10581071 props = self ._last_properties or {}
@@ -1132,8 +1145,13 @@ def estimate_counterfactual(
11321145 from causal_copilot .mcp .estimation import run_counterfactual
11331146
11341147 return run_counterfactual (
1135- df , dag_adj , names , treatment , outcome ,
1136- intervention_value , observed_row_index ,
1148+ df ,
1149+ dag_adj ,
1150+ names ,
1151+ treatment ,
1152+ outcome ,
1153+ intervention_value ,
1154+ observed_row_index ,
11371155 )
11381156
11391157 def simulate_intervention (
@@ -1169,8 +1187,14 @@ def simulate_intervention(
11691187 from causal_copilot .mcp .estimation import run_intervention_simulation
11701188
11711189 return run_intervention_simulation (
1172- df , dag_adj , names , treatment , outcome ,
1173- intervention_value , shift , n_samples ,
1190+ df ,
1191+ dag_adj ,
1192+ names ,
1193+ treatment ,
1194+ outcome ,
1195+ intervention_value ,
1196+ shift ,
1197+ n_samples ,
11741198 )
11751199
11761200 def attribute_anomaly (
@@ -1202,8 +1226,12 @@ def attribute_anomaly(
12021226 from causal_copilot .mcp .estimation import run_anomaly_attribution
12031227
12041228 return run_anomaly_attribution (
1205- df , dag_adj , names , target_node ,
1206- threshold_percentile , n_samples ,
1229+ df ,
1230+ dag_adj ,
1231+ names ,
1232+ target_node ,
1233+ threshold_percentile ,
1234+ n_samples ,
12071235 )
12081236
12091237 def attribute_distribution_change (
@@ -1240,7 +1268,11 @@ def attribute_distribution_change(
12401268 from causal_copilot .mcp .estimation import run_distribution_change
12411269
12421270 return run_distribution_change (
1243- df_old , data_new , dag_adj , names , target_node ,
1271+ df_old ,
1272+ data_new ,
1273+ dag_adj ,
1274+ names ,
1275+ target_node ,
12441276 )
12451277
12461278 # --- Analysis & validation ---
0 commit comments