Skip to content

Commit 3a06e99

Browse files
cauchyturingclaude
and committed
style: ruff format all modified files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7329fce commit 3a06e99

File tree

6 files changed

+158
-84
lines changed

6 files changed

+158
-84
lines changed

causal_copilot/copilot.py

Lines changed: 96 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,10 @@ def estimate_effect(
660660

661661
# --- Treatment type dispatch (4 cases, matching original) ---
662662
_, T0, T1, treatment_kind = prepare_treatment(
663-
df, treatment, T0=control_value, T1=treatment_value,
663+
df,
664+
treatment,
665+
T0=control_value,
666+
T1=treatment_value,
664667
)
665668
control_value = T0
666669
treatment_value = T1
@@ -674,9 +677,7 @@ def estimate_effect(
674677
is_lg = is_linear and is_gaussian
675678
policy = check_inference_policy(adj, is_linear_gaussian=is_lg)
676679
if not policy["allow_inference"]:
677-
result.warnings = warnings_list + [
678-
f"Inference rejected: {policy['reason']}"
679-
]
680+
result.warnings = warnings_list + [f"Inference rejected: {policy['reason']}"]
680681
result.summary += f" Inference rejected: {policy['reason']}"
681682
return result
682683
warnings_list.append("CPDAG: undirected edges dropped for estimation")
@@ -701,8 +702,7 @@ def estimate_effect(
701702
if had_edge and clean_adj[o_idx, t_idx] == 0:
702703
clean_adj[o_idx, t_idx] = 1
703704
warnings_list.append(
704-
f"Restored {treatment}->{outcome} as directed for estimation "
705-
f"(was undirected in CPDAG)"
705+
f"Restored {treatment}->{outcome} as directed for estimation (was undirected in CPDAG)"
706706
)
707707

708708
# --- Determine confounders (full adj matrix type check) ---
@@ -711,14 +711,15 @@ def estimate_effect(
711711
potential_conf = []
712712
else:
713713
conf_list, potential_conf = identify_confounders(
714-
adj, names, treatment, outcome,
714+
adj,
715+
names,
716+
treatment,
717+
outcome,
715718
)
716719
if potential_conf and not conf_list:
717720
# Use potential confounders when no confirmed ones exist
718721
conf_list = potential_conf
719-
warnings_list.append(
720-
f"Using potential confounders from undirected edges: {conf_list}"
721-
)
722+
warnings_list.append(f"Using potential confounders from undirected edges: {conf_list}")
722723

723724
# --- Check for IV ---
724725
has_iv = bool(instrument)
@@ -730,9 +731,7 @@ def estimate_effect(
730731
continue
731732
if clean_adj[o_idx, z_idx] == 1:
732733
continue
733-
has_parents = any(
734-
clean_adj[z_idx, j] == 1 for j in range(n) if j != z_idx
735-
)
734+
has_parents = any(clean_adj[z_idx, j] == 1 for j in range(n) if j != z_idx)
736735
if not has_parents:
737736
has_iv = True
738737
if not instrument:
@@ -742,7 +741,9 @@ def estimate_effect(
742741
# --- Auto-select method (data-driven, replaces LLM Filter) ---
743742
if method is None:
744743
selected_method = select_estimation_method(
745-
df, treatment, treatment_kind,
744+
df,
745+
treatment,
746+
treatment_kind,
746747
is_linear=is_linear,
747748
is_gaussian=is_gaussian,
748749
n_features=len(names) - 1,
@@ -765,40 +766,52 @@ def estimate_effect(
765766
if selected_method == "linear":
766767
dot_graph = _adj_to_dot(clean_adj, names)
767768
estimates = estimate_linear(
768-
df, dot_graph, treatment, outcome,
769-
control_value, treatment_value,
769+
df,
770+
dot_graph,
771+
treatment,
772+
outcome,
773+
control_value,
774+
treatment_value,
770775
)
771776
elif selected_method == "matching":
772-
match_conf = conf_list if conf_list else [
773-
c for c in names if c != treatment and c != outcome
774-
]
777+
match_conf = conf_list if conf_list else [c for c in names if c != treatment and c != outcome]
775778
estimates = estimate_matching(
776-
df, treatment, outcome, match_conf,
777-
int(control_value), int(treatment_value),
779+
df,
780+
treatment,
781+
outcome,
782+
match_conf,
783+
int(control_value),
784+
int(treatment_value),
778785
)
779786
elif selected_method == "dml":
780787
X_col = [c for c in names if c != treatment and c != outcome and c not in conf_list]
781788
if not X_col:
782-
X_col = conf_list[:] if conf_list else [
783-
c for c in names if c != treatment and c != outcome
784-
]
789+
X_col = conf_list[:] if conf_list else [c for c in names if c != treatment and c != outcome]
785790
W_col = conf_list if conf_list else []
786791
estimates = estimate_dml(
787-
df, treatment, outcome, X_col, W_col,
788-
control_value, treatment_value,
792+
df,
793+
treatment,
794+
outcome,
795+
X_col,
796+
W_col,
797+
control_value,
798+
treatment_value,
789799
is_linear=is_linear,
790800
treatment_kind=treatment_kind,
791801
)
792802
elif selected_method == "drl":
793803
X_col = [c for c in names if c != treatment and c != outcome and c not in conf_list]
794804
if not X_col:
795-
X_col = conf_list[:] if conf_list else [
796-
c for c in names if c != treatment and c != outcome
797-
]
805+
X_col = conf_list[:] if conf_list else [c for c in names if c != treatment and c != outcome]
798806
W_col = conf_list if conf_list else []
799807
estimates = estimate_drl(
800-
df, treatment, outcome, X_col, W_col,
801-
control_value, treatment_value,
808+
df,
809+
treatment,
810+
outcome,
811+
X_col,
812+
W_col,
813+
control_value,
814+
treatment_value,
802815
is_linear=is_linear,
803816
treatment_kind=treatment_kind,
804817
)
@@ -807,31 +820,37 @@ def estimate_effect(
807820
# Pick learner variant based on data (matches original)
808821
learner = "t" if is_linear else "x"
809822
estimates = estimate_metalearner(
810-
df, treatment, outcome, X_col,
811-
control_value, treatment_value,
823+
df,
824+
treatment,
825+
outcome,
826+
X_col,
827+
control_value,
828+
treatment_value,
812829
learner=learner,
813830
)
814831
elif selected_method == "iv":
815832
if not instrument:
816833
raise ValueError(
817-
"No valid instrument found in graph. "
818-
"Provide instrument= or use a different method."
834+
"No valid instrument found in graph. Provide instrument= or use a different method."
819835
)
820836
X_col = [c for c in names if c not in (treatment, outcome, instrument)]
821837
W_col = conf_list if conf_list else []
822838
estimates = estimate_iv(
823-
df, treatment, outcome, instrument, X_col, W_col,
824-
control_value, treatment_value,
839+
df,
840+
treatment,
841+
outcome,
842+
instrument,
843+
X_col,
844+
W_col,
845+
control_value,
846+
treatment_value,
825847
)
826848
else:
827849
raise ValueError(
828-
f"Unknown method '{selected_method}'. "
829-
"Use: linear, matching, dml, drl, metalearner, iv."
850+
f"Unknown method '{selected_method}'. Use: linear, matching, dml, drl, metalearner, iv."
830851
)
831852
except ImportError as e:
832-
raise ImportError(
833-
"Estimation requires inference extras: pip install causal-copilot[inference]"
834-
) from e
853+
raise ImportError("Estimation requires inference extras: pip install causal-copilot[inference]") from e
835854

836855
# --- Build TreatmentEffect ---
837856
ate_info = estimates.get("ate", {})
@@ -961,8 +980,12 @@ def refute_estimate(
961980
from causal_copilot.mcp.estimation import run_refutation
962981

963982
return run_refutation(
964-
df, dot_graph, treatment, outcome,
965-
control_value, treatment_value,
983+
df,
984+
dot_graph,
985+
treatment,
986+
outcome,
987+
control_value,
988+
treatment_value,
966989
confounders=conf_list,
967990
shap_top_feature=shap_top,
968991
)
@@ -1003,11 +1026,7 @@ def _to_dag(self, adj: np.ndarray) -> np.ndarray:
10031026
# Orient undirected edges: lower column index → higher column index
10041027
for i in range(n):
10051028
for j in range(i + 1, n):
1006-
has_undirected = (
1007-
(adj[i, j] == 2 or adj[j, i] == 2)
1008-
and dag[i, j] == 0
1009-
and dag[j, i] == 0
1010-
)
1029+
has_undirected = (adj[i, j] == 2 or adj[j, i] == 2) and dag[i, j] == 0 and dag[j, i] == 0
10111030
if has_undirected:
10121031
dag[j, i] = 1 # i→j (adj[j,i]=1 means i causes j)
10131032

@@ -1045,14 +1064,8 @@ def inspect_graph(
10451064

10461065
# Edge statistics
10471066
n_directed = int(np.sum(adj == 1))
1048-
n_undirected = sum(
1049-
1 for i in range(n) for j in range(i + 1, n)
1050-
if adj[i, j] == 2 or adj[j, i] == 2
1051-
)
1052-
n_bidirected = sum(
1053-
1 for i in range(n) for j in range(i + 1, n)
1054-
if adj[i, j] == 3 or adj[j, i] == 3
1055-
)
1067+
n_undirected = sum(1 for i in range(n) for j in range(i + 1, n) if adj[i, j] == 2 or adj[j, i] == 2)
1068+
n_bidirected = sum(1 for i in range(n) for j in range(i + 1, n) if adj[i, j] == 3 or adj[j, i] == 3)
10561069

10571070
# Inference policy
10581071
props = self._last_properties or {}
@@ -1132,8 +1145,13 @@ def estimate_counterfactual(
11321145
from causal_copilot.mcp.estimation import run_counterfactual
11331146

11341147
return run_counterfactual(
1135-
df, dag_adj, names, treatment, outcome,
1136-
intervention_value, observed_row_index,
1148+
df,
1149+
dag_adj,
1150+
names,
1151+
treatment,
1152+
outcome,
1153+
intervention_value,
1154+
observed_row_index,
11371155
)
11381156

11391157
def simulate_intervention(
@@ -1169,8 +1187,14 @@ def simulate_intervention(
11691187
from causal_copilot.mcp.estimation import run_intervention_simulation
11701188

11711189
return run_intervention_simulation(
1172-
df, dag_adj, names, treatment, outcome,
1173-
intervention_value, shift, n_samples,
1190+
df,
1191+
dag_adj,
1192+
names,
1193+
treatment,
1194+
outcome,
1195+
intervention_value,
1196+
shift,
1197+
n_samples,
11741198
)
11751199

11761200
def attribute_anomaly(
@@ -1202,8 +1226,12 @@ def attribute_anomaly(
12021226
from causal_copilot.mcp.estimation import run_anomaly_attribution
12031227

12041228
return run_anomaly_attribution(
1205-
df, dag_adj, names, target_node,
1206-
threshold_percentile, n_samples,
1229+
df,
1230+
dag_adj,
1231+
names,
1232+
target_node,
1233+
threshold_percentile,
1234+
n_samples,
12071235
)
12081236

12091237
def attribute_distribution_change(
@@ -1240,7 +1268,11 @@ def attribute_distribution_change(
12401268
from causal_copilot.mcp.estimation import run_distribution_change
12411269

12421270
return run_distribution_change(
1243-
df_old, data_new, dag_adj, names, target_node,
1271+
df_old,
1272+
data_new,
1273+
dag_adj,
1274+
names,
1275+
target_node,
12441276
)
12451277

12461278
# --- Analysis & validation ---

causal_copilot/mcp/estimation.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def estimate_dml(
146146
from causal_copilot.mcp.offline import get_default_estimation_config
147147

148148
config = get_default_estimation_config(
149-
"dml", data, treatment,
149+
"dml",
150+
data,
151+
treatment,
150152
outcome=outcome,
151153
is_linear=is_linear,
152154
treatment_kind=treatment_kind,
@@ -240,7 +242,9 @@ def estimate_drl(
240242
from causal_copilot.mcp.offline import get_default_estimation_config
241243

242244
config = get_default_estimation_config(
243-
"drl", data, treatment,
245+
"drl",
246+
data,
247+
treatment,
244248
outcome=outcome,
245249
is_linear=is_linear,
246250
treatment_kind=treatment_kind,
@@ -372,21 +376,26 @@ def estimate_metalearner(
372376
elif learner == "x":
373377
try:
374378
from xgboost import XGBRegressor
379+
375380
base_model = XGBRegressor(objective="reg:squarederror", n_estimators=100)
376381
except ImportError:
377382
from sklearn.ensemble import GradientBoostingRegressor
383+
378384
base_model = GradientBoostingRegressor(n_estimators=100)
379385
model = XLearner(
380386
models=base_model,
381387
propensity_model=LogisticRegression(max_iter=1000),
382388
)
383389
elif learner == "da":
384390
from econml.metalearners import DomainAdaptationLearner
391+
385392
try:
386393
from xgboost import XGBRegressor
394+
387395
base_model = XGBRegressor(objective="reg:squarederror", n_estimators=100)
388396
except ImportError:
389397
from sklearn.ensemble import GradientBoostingRegressor
398+
390399
base_model = GradientBoostingRegressor(n_estimators=100)
391400
model = DomainAdaptationLearner(
392401
models=base_model,

causal_copilot/mcp/offline.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def _pick_model(is_binary: bool) -> object:
258258
config = {
259259
"algo": algo,
260260
"model_regression": _pick_model(is_binary=False), # outcome model
261-
"model_propensity": _pick_model(is_binary=True), # treatment model
261+
"model_propensity": _pick_model(is_binary=True), # treatment model
262262
}
263263
return config
264264

@@ -307,8 +307,11 @@ def get_default_estimation_config(
307307

308308
if method == "dml":
309309
algo = select_dml_variant(
310-
data, treatment, treatment_kind,
311-
is_linear=is_linear, n_features=n_features,
310+
data,
311+
treatment,
312+
treatment_kind,
313+
is_linear=is_linear,
314+
n_features=n_features,
312315
)
313316
elif method == "drl":
314317
if not is_linear and treatment_kind in ("binary", "discrete"):
@@ -325,6 +328,10 @@ def get_default_estimation_config(
325328
outcome_col = outcome or ([c for c in data.columns if c != treatment][0])
326329

327330
return select_models_for_method(
328-
method, algo, data, treatment, outcome_col,
331+
method,
332+
algo,
333+
data,
334+
treatment,
335+
outcome_col,
329336
is_linear=is_linear,
330337
)

0 commit comments

Comments (0)