Skip to content

Commit a324d4a

Browse files
Address SHAP issues
1 parent 97609e0 commit a324d4a

31 files changed

+819
-411
lines changed

bluecast/ai/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,15 +35,26 @@ def _create_provider(config: AIConfig):
3535

3636
if config.provider == "gemini":
3737
from bluecast.ai.providers.gemini import GeminiProvider
38-
return GeminiProvider(api_key=config.api_key, model=model, temperature=config.temperature)
38+
39+
return GeminiProvider(
40+
api_key=config.api_key, model=model, temperature=config.temperature
41+
)
3942
elif config.provider == "openai":
4043
from bluecast.ai.providers.openai_provider import OpenAIProvider
41-
return OpenAIProvider(api_key=config.api_key, model=model, temperature=config.temperature)
44+
45+
return OpenAIProvider(
46+
api_key=config.api_key, model=model, temperature=config.temperature
47+
)
4248
elif config.provider == "anthropic":
4349
from bluecast.ai.providers.anthropic_provider import AnthropicProvider
44-
return AnthropicProvider(api_key=config.api_key, model=model, temperature=config.temperature)
50+
51+
return AnthropicProvider(
52+
api_key=config.api_key, model=model, temperature=config.temperature
53+
)
4554
else:
46-
raise ValueError(f"Unknown provider: {config.provider}. Use 'gemini', 'openai', or 'anthropic'.")
55+
raise ValueError(
56+
f"Unknown provider: {config.provider}. Use 'gemini', 'openai', or 'anthropic'."
57+
)
4758

4859

4960
class BlueCastAI:

bluecast/ai/agents/base.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,12 @@
66
from typing import Any, Callable, Dict, List, Optional
77

88
from bluecast.ai.context import SharedContext
9-
from bluecast.ai.providers.base import BaseLLMProvider, LLMResponse, Message, ToolDefinition
9+
from bluecast.ai.providers.base import (
10+
BaseLLMProvider,
11+
LLMResponse,
12+
Message,
13+
ToolDefinition,
14+
)
1015

1116
logger = logging.getLogger(__name__)
1217

@@ -36,15 +41,15 @@ def __init__(
3641
@property
3742
@abstractmethod
3843
def name(self) -> str:
39-
...
44+
pass
4045

4146
@abstractmethod
4247
def system_prompt(self) -> str:
43-
...
48+
pass
4449

4550
@abstractmethod
4651
def get_tools(self) -> List[ToolDefinition]:
47-
...
52+
pass
4853

4954
def register_tool_impl(self, name: str, func: Callable) -> None:
5055
self._tool_implementations[name] = func
@@ -75,7 +80,9 @@ def run(self, task: str) -> str:
7580
print(f" [{self.name}] Starting: {task[:80]}...")
7681

7782
self.context.log(
78-
self.name, task[:500], event_type="task",
83+
self.name,
84+
task[:500],
85+
event_type="task",
7986
metadata={"full_task_length": len(task)},
8087
)
8188

@@ -87,7 +94,7 @@ def run(self, task: str) -> str:
8794
tools = self.get_tools()
8895
response: Optional[LLMResponse] = None
8996

90-
for iteration in range(MAX_TOOL_ITERATIONS):
97+
for _iteration in range(MAX_TOOL_ITERATIONS):
9198
response = self.llm.chat(messages, tools=tools if tools else None)
9299

93100
if response.has_tool_calls:
@@ -145,7 +152,9 @@ def run(self, task: str) -> str:
145152

146153
final_text = response.text if response else "Agent reached max tool iterations."
147154
self.context.log(
148-
self.name, final_text[:300], event_type="error",
155+
self.name,
156+
final_text[:300],
157+
event_type="error",
149158
metadata={"reason": "max_iterations_reached"},
150159
)
151160
return final_text

bluecast/ai/agents/feature_engineer.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""Feature engineer agent: creates new features based on data analysis."""
22

3-
import json
43
from typing import List
54

65
from bluecast.ai.agents.base import BaseAgent
@@ -19,8 +18,15 @@ def __init__(self, *args, **kwargs):
1918
def _create_feature_wrapper(self, feature_code: str, description: str = "", **kw):
2019
if self.context.engineered_df is not None:
2120
df = self.context.engineered_df
22-
else:
21+
elif self.context.df_train is not None:
2322
df = self.context.df_train.copy()
23+
else:
24+
return {
25+
"success": False,
26+
"new_columns": [],
27+
"shape": [],
28+
"error": "No training data available.",
29+
}
2430

2531
result = tool_create_feature(df, feature_code)
2632
if result["success"]:
@@ -40,7 +46,9 @@ def system_prompt(self) -> str:
4046
profile = self.context.data_profile or "Not yet profiled."
4147
hints = ""
4248
if self.context.data_warnings:
43-
hints = "\nData warnings:\n" + "\n".join(f"- {w}" for w in self.context.data_warnings)
49+
hints = "\nData warnings:\n" + "\n".join(
50+
f"- {w}" for w in self.context.data_warnings
51+
)
4452

4553
return f"""You are a feature engineer for the BlueCast AutoML framework.
4654

bluecast/ai/agents/pipeline_builder.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""Pipeline builder agent: generates and runs BlueCast pipelines."""
22

3-
import json
43
from typing import List
54

65
from bluecast.ai.agents.base import BaseAgent
@@ -17,7 +16,11 @@ def __init__(self, *args, **kwargs):
1716
)
1817

1918
def _build_wrapper(self, **config):
20-
df = self.context.engineered_df if self.context.engineered_df is not None else self.context.df_train
19+
df = (
20+
self.context.engineered_df
21+
if self.context.engineered_df is not None
22+
else self.context.df_train
23+
)
2124
result = tool_build_and_run_pipeline(df, self.context.target_col, config)
2225

2326
run_record = {
@@ -77,7 +80,9 @@ def _generate_pipeline_code(self, config: dict) -> None:
7780

7881
strategy = config.get("ensemble_strategy", "mean")
7982
if config.get("use_cv", True):
80-
lines.append(f"ensemble_config = EnsembleConfig(ensemble_strategy=\"{strategy}\")")
83+
lines.append(
84+
f'ensemble_config = EnsembleConfig(ensemble_strategy="{strategy}")'
85+
)
8186
lines.append("")
8287

8388
lines.append("pipeline = BlueCastAuto(")
@@ -88,7 +93,9 @@ def _generate_pipeline_code(self, config: dict) -> None:
8893
lines.append(" ensemble_config=ensemble_config,")
8994
lines.append(")")
9095
lines.append("")
91-
lines.append(f"pipeline.fit_eval(df_train, target_col=\"{self.context.target_col}\")")
96+
lines.append(
97+
f'pipeline.fit_eval(df_train, target_col="{self.context.target_col}")'
98+
)
9299

93100
self.context.pipeline_code = "\n".join(lines)
94101

bluecast/ai/agents/reporter.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,20 @@ def build_report_task(self) -> str:
6767
)
6868

6969
if self.context.data_profile:
70-
sections.append(f"\nData profile:\n{json.dumps(self.context.data_profile, indent=2, default=str)[:3000]}")
70+
sections.append(
71+
f"\nData profile:\n{json.dumps(self.context.data_profile, indent=2, default=str)[:3000]}"
72+
)
7173

7274
if self.context.data_warnings:
73-
sections.append("\nWarnings:\n" + "\n".join(f"- {w}" for w in self.context.data_warnings))
75+
sections.append(
76+
"\nWarnings:\n"
77+
+ "\n".join(f"- {w}" for w in self.context.data_warnings)
78+
)
7479

7580
if self.context.feature_engineering_code:
76-
sections.append(f"\nFeature engineering code:\n```python\n{self.context.feature_engineering_code}\n```")
81+
sections.append(
82+
f"\nFeature engineering code:\n```python\n{self.context.feature_engineering_code}\n```"
83+
)
7784
else:
7885
sections.append("\nNo feature engineering was applied.")
7986

@@ -90,10 +97,14 @@ def build_report_task(self) -> str:
9097
sections.append(f"\nBest metrics: {self.context.best_metrics}")
9198

9299
if self.context.web_research:
93-
sections.append(f"\nWeb research findings:\n{self.context.web_research[:1000]}")
100+
sections.append(
101+
f"\nWeb research findings:\n{self.context.web_research[:1000]}"
102+
)
94103

95104
if self.context.pipeline_code:
96-
sections.append(f"\nGenerated pipeline code:\n```python\n{self.context.pipeline_code}\n```")
105+
sections.append(
106+
f"\nGenerated pipeline code:\n```python\n{self.context.pipeline_code}\n```"
107+
)
97108

98109
# Include a selection of the structured log
99110
log_entries = self.context.structured_log[-30:]

bluecast/ai/agents/researcher.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@
1010
class ResearcherAgent(BaseAgent):
1111
def __init__(self, *args, **kwargs):
1212
super().__init__(*args, **kwargs)
13-
self.register_tool_impl("web_search", lambda query, **kw: tool_web_search(query))
13+
self.register_tool_impl(
14+
"web_search", lambda query, **kw: tool_web_search(query)
15+
)
1416

1517
@property
1618
def name(self) -> str:

bluecast/ai/context.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def get_data_summary(self) -> str:
116116
df = self.get_working_df()
117117
lines = []
118118

119-
if self.was_sampled:
119+
if self.was_sampled and self.original_shape is not None:
120120
lines.append(
121121
f"NOTE: Working on a stratified sample of {df.shape[0]} rows "
122122
f"(original: {self.original_shape[0]} rows x {self.original_shape[1]} cols). "

bluecast/ai/orchestrator.py

Lines changed: 53 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from typing import Optional
88

99
import dill
10-
import numpy as np
1110
import pandas as pd
1211

1312
from bluecast.ai.agents.data_analyst import DataAnalystAgent
@@ -116,14 +115,20 @@ def _apply_smart_sampling(self) -> None:
116115
target = self.context.target_col
117116
if target in df.columns and df[target].nunique() <= 20:
118117
# Stratified sampling for classification
119-
sample_df = df.groupby(target, group_keys=False).apply(
120-
lambda x: x.sample(
121-
n=min(len(x), max(1, int(max_rows * len(x) / n_rows))),
122-
random_state=42,
118+
sample_df = (
119+
df.groupby(target, group_keys=False)
120+
.apply(
121+
lambda x: x.sample(
122+
n=min(len(x), max(1, int(max_rows * len(x) / n_rows))),
123+
random_state=42,
124+
)
123125
)
124-
).reset_index(drop=True)
126+
.reset_index(drop=True)
127+
)
125128
else:
126-
sample_df = df.sample(n=max_rows, random_state=42).reset_index(drop=True)
129+
sample_df = df.sample(n=max_rows, random_state=42).reset_index(
130+
drop=True
131+
)
127132

128133
msg = (
129134
f"Dataset sampled: {n_rows} -> {len(sample_df)} rows "
@@ -167,7 +172,8 @@ def _save_checkpoint(self, step_name: str) -> None:
167172
with open(path, "wb") as f:
168173
dill.dump(self.context, f)
169174
self.context.log(
170-
"Orchestrator", f"Checkpoint saved after '{step_name}'",
175+
"Orchestrator",
176+
f"Checkpoint saved after '{step_name}'",
171177
event_type="checkpoint",
172178
)
173179
if self.config.verbose:
@@ -202,8 +208,13 @@ def _load_checkpoint(self) -> bool:
202208

203209
# Re-attach context to all agents
204210
for agent in [
205-
self.planner, self.analyst, self.engineer,
206-
self.builder, self.evaluator, self.researcher, self.reporter,
211+
self.planner,
212+
self.analyst,
213+
self.engineer,
214+
self.builder,
215+
self.evaluator,
216+
self.researcher,
217+
self.reporter,
207218
]:
208219
agent.context = self.context
209220

@@ -233,7 +244,7 @@ def run(self) -> BlueCastAIResult:
233244
print("BlueCastAI - Multi-Agent AutoML Pipeline")
234245
print("=" * 60)
235246

236-
resumed = self._load_checkpoint()
247+
self._load_checkpoint()
237248

238249
# --- Step 0: Smart sampling ---
239250
if not self._is_step_done("sampling"):
@@ -273,7 +284,9 @@ def run(self) -> BlueCastAIResult:
273284

274285
# --- Step 5: Build-Evaluate-Improve loop ---
275286
if not self._is_step_done("build_loop"):
276-
max_iterations = plan.get("max_iterations", self.config.get_max_iterations())
287+
max_iterations = plan.get(
288+
"max_iterations", self.config.get_max_iterations()
289+
)
277290
self._step_build_loop(plan, max_iterations)
278291
self._save_checkpoint("build_loop")
279292

@@ -324,8 +337,10 @@ def _step_plan(self) -> dict:
324337

325338
self.context.class_problem = plan.get("class_problem", "binary")
326339
self.context.log(
327-
"Orchestrator", f"Plan: {json.dumps(plan, indent=2)}",
328-
event_type="plan", metadata={"plan": plan},
340+
"Orchestrator",
341+
f"Plan: {json.dumps(plan, indent=2)}",
342+
event_type="plan",
343+
metadata={"plan": plan},
329344
)
330345

331346
if self.config.verbose:
@@ -341,7 +356,11 @@ def _step_plan(self) -> dict:
341356
def _reconstruct_plan(self) -> dict:
342357
"""Reconstruct the plan from structured log metadata."""
343358
for entry in self.context.structured_log:
344-
if entry.event_type == "plan" and entry.metadata and "plan" in entry.metadata:
359+
if (
360+
entry.event_type == "plan"
361+
and entry.metadata
362+
and "plan" in entry.metadata
363+
):
345364
return entry.metadata["plan"]
346365
return self.planner._default_plan()
347366

@@ -362,7 +381,14 @@ def _step_analyze(self) -> None:
362381
)
363382
self.context.data_profile = {"summary": result}
364383

365-
for keyword in ["leakage", "imbalance", "missing", "null", "duplicate", "constant"]:
384+
for keyword in [
385+
"leakage",
386+
"imbalance",
387+
"missing",
388+
"null",
389+
"duplicate",
390+
"constant",
391+
]:
366392
if keyword in result.lower():
367393
self.context.data_warnings.append(
368394
f"Data analyst flagged: {keyword} detected"
@@ -373,7 +399,9 @@ def _step_feature_engineer(self, plan: dict) -> None:
373399
print("\nStep 4: Engineering features...")
374400

375401
hints = plan.get("feature_engineering_hints", [])
376-
hint_text = "\n".join(f"- {h}" for h in hints) if hints else "Use your judgment."
402+
hint_text = (
403+
"\n".join(f"- {h}" for h in hints) if hints else "Use your judgment."
404+
)
377405

378406
task = (
379407
f"Create useful features for this {self.context.class_problem} problem.\n"
@@ -383,7 +411,11 @@ def _step_feature_engineer(self, plan: dict) -> None:
383411
self.engineer.run(task)
384412

385413
if self.context.engineered_df is not None and self.config.verbose:
386-
orig_cols = len(self.context.df_train.columns)
414+
orig_cols = (
415+
len(self.context.df_train.columns)
416+
if self.context.df_train is not None
417+
else 0
418+
)
387419
new_cols = len(self.context.engineered_df.columns)
388420
print(f" Features: {orig_cols} -> {new_cols} columns")
389421

@@ -398,7 +430,9 @@ def _step_build_loop(self, plan: dict, max_iterations: int) -> None:
398430
build_task = self._create_build_task(plan, iteration)
399431
self.builder.run(build_task)
400432

401-
latest_run = self.context.run_history[-1] if self.context.run_history else None
433+
latest_run = (
434+
self.context.run_history[-1] if self.context.run_history else None
435+
)
402436
if latest_run and self.config.verbose:
403437
status = "OK" if latest_run["success"] else "FAILED"
404438
print(f" Result [{status}]: {latest_run.get('metrics', {})}")

0 commit comments

Comments
 (0)