Skip to content

Commit c288b79

Browse files
Update test coverage
1 parent 9b59583 commit c288b79

File tree

5 files changed

+1030
-0
lines changed

5 files changed

+1030
-0
lines changed

bluecast/tests/test_ai_module.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
"""Tests for the AI module (config, context, tools, result) - no LLM API keys needed."""
2+
3+
import json
4+
import os
5+
import tempfile
6+
import time
7+
8+
import numpy as np
9+
import pandas as pd
10+
import pytest
11+
12+
from bluecast.ai.config import AIConfig
13+
from bluecast.ai.context import AgentLogEntry, SharedContext
14+
from bluecast.ai.result import BlueCastAIResult
15+
from bluecast.ai.tools import (
16+
TOOL_DEFINITIONS,
17+
_serialize_metrics,
18+
tool_check_correlations,
19+
tool_check_leakage,
20+
tool_create_feature,
21+
tool_describe_data,
22+
)
23+
24+
25+
# --- AIConfig ---
26+
class TestAIConfig:
    """Exercise AIConfig construction defaults and its derived accessors."""

    def test_defaults(self):
        cfg = AIConfig(api_key="test")
        assert cfg.provider == "gemini"
        assert cfg.temperature == 0.2
        assert cfg.max_rows_for_agents == 50_000
        assert cfg.checkpoint_dir is None

    def test_get_model_name_default(self):
        # Each provider resolves to its own default model identifier.
        defaults = {
            "gemini": "gemini-2.5-flash",
            "openai": "gpt-4o",
            "anthropic": "claude-sonnet-4-20250514",
        }
        for provider, expected_model in defaults.items():
            cfg = AIConfig(api_key="t", provider=provider)
            assert cfg.get_model_name() == expected_model

    def test_get_model_name_custom(self):
        # An explicitly supplied model name overrides the provider default.
        cfg = AIConfig(api_key="t", model="custom-model")
        assert cfg.get_model_name() == "custom-model"

    def test_get_max_iterations(self):
        # A positive explicit value wins outright.
        assert AIConfig(api_key="t", max_iterations=7).get_max_iterations() == 7
        # Otherwise the mode determines the iteration budget.
        fast_cfg = AIConfig(api_key="t", max_iterations=0, mode="fast")
        assert fast_cfg.get_max_iterations() == 1
        precise_cfg = AIConfig(api_key="t", max_iterations=0, mode="precise")
        assert precise_cfg.get_max_iterations() == 5
52+
53+
54+
# --- AgentLogEntry ---
55+
class TestAgentLogEntry:
    """Verify AgentLogEntry retains its fields and surfaces them via str()."""

    def test_creation(self):
        log_entry = AgentLogEntry(
            timestamp=time.time(),
            agent="TestAgent",
            event_type="task",
            content="hello",
            metadata={"key": "val"},
        )
        assert log_entry.agent == "TestAgent"
        assert log_entry.event_type == "task"
        # The string form should mention both the agent and the event type.
        rendered = str(log_entry)
        assert "TestAgent" in rendered
        assert "task" in rendered
68+
69+
70+
# --- SharedContext ---
71+
class TestSharedContext:
    """Behavioral tests for SharedContext logging and dataframe accessors."""

    def test_log(self):
        context = SharedContext()
        context.log("Agent1", "message", event_type="info", metadata={"x": 1})
        assert len(context.structured_log) == 1
        assert context.structured_log[0].agent == "Agent1"

    def test_agent_log_backward_compat(self):
        # Logging also feeds the legacy plain-string agent_log list.
        context = SharedContext()
        context.log("A", "msg")
        assert len(context.agent_log) == 1
        assert "[A]" in context.agent_log[0]

    def test_get_data_summary_empty(self):
        assert SharedContext().get_data_summary() == "No data loaded."

    def test_get_data_summary_classification(self):
        frame = pd.DataFrame({"a": [1, 2, 3], "target": [0, 1, 0]})
        context = SharedContext(df_train=frame, target_col="target")
        text = context.get_data_summary()
        assert "3 rows" in text
        assert "target" in text

    def test_get_data_summary_regression(self):
        frame = pd.DataFrame({"a": range(100), "target": np.random.randn(100)})
        context = SharedContext(df_train=frame, target_col="target")
        # A float target should be summarized as continuous.
        assert "continuous" in context.get_data_summary()

    def test_get_data_summary_sampled(self):
        frame = pd.DataFrame({"a": [1, 2], "target": [0, 1]})
        context = SharedContext(
            df_train=frame,
            target_col="target",
            df_sample=frame.head(1),
            was_sampled=True,
            original_shape=(1000, 5),
        )
        # The summary should disclose that a sample is in use.
        assert "sample" in context.get_data_summary().lower()

    def test_get_working_df(self):
        full_frame = pd.DataFrame({"a": [1, 2, 3]})
        sampled_frame = pd.DataFrame({"a": [1]})
        context = SharedContext(df_train=full_frame, df_sample=sampled_frame)
        # The working frame prefers the sample when one exists.
        assert len(context.get_working_df()) == 1

    def test_get_working_df_no_sample(self):
        context = SharedContext(df_train=pd.DataFrame({"a": [1, 2, 3]}))
        assert len(context.get_working_df()) == 3

    def test_get_working_df_no_data_raises(self):
        with pytest.raises(ValueError):
            SharedContext().get_working_df()

    def test_get_full_df(self):
        frame = pd.DataFrame({"a": [1, 2, 3]})
        context = SharedContext(df_train=frame, df_sample=frame.head(1))
        # get_full_df ignores the sample and returns the complete training frame.
        assert len(context.get_full_df()) == 3
133+
134+
135+
# --- Tools ---
136+
class TestTools:
    """Tests for the standalone data tools and the metric serialization helper."""

    @pytest.fixture
    def sample_df(self):
        # Deterministic synthetic frame: two numerics, one categorical, binary target.
        rng = np.random.default_rng(42)
        return pd.DataFrame(
            {
                "num1": rng.normal(0, 1, 100),
                "num2": rng.normal(5, 2, 100),
                "cat": rng.choice(["a", "b", "c"], 100),
                "target": rng.choice([0, 1], 100),
            }
        )

    def test_describe_data(self, sample_df):
        description = tool_describe_data(sample_df, "target")
        assert "Shape" in description
        lowered = description.lower()
        assert "target" in lowered
        assert "binary" in lowered or "2" in description

    def test_check_correlations(self, sample_df):
        lowered = tool_check_correlations(sample_df, "target").lower()
        assert "correlation" in lowered or "correlations" in lowered

    def test_check_leakage(self, sample_df):
        assert "leakage" in tool_check_leakage(sample_df, "target").lower()

    def test_create_feature_success(self, sample_df):
        outcome = tool_create_feature(
            sample_df, "df['ratio'] = df['num1'] / (df['num2'] + 1)"
        )
        assert outcome["success"] is True
        assert "ratio" in outcome["new_columns"]

    def test_create_feature_failure(self, sample_df):
        # Referencing a missing column must surface as a failed result, not a raise.
        outcome = tool_create_feature(sample_df, "df['bad'] = df['nonexistent'] * 2")
        assert outcome["success"] is False
        assert outcome["error"] is not None

    def test_tool_definitions(self):
        assert len(TOOL_DEFINITIONS) >= 5
        for tool_name, definition in TOOL_DEFINITIONS.items():
            # Registry key, definition name, description, and schema must line up.
            assert definition.name == tool_name
            assert len(definition.description) > 10
            assert "type" in definition.parameters

    def test_serialize_metrics_dict(self):
        serialized = _serialize_metrics(
            {"roc_auc": 0.85, "accuracy": 0.9, "report": "text"}
        )
        assert serialized["roc_auc"] == 0.85

    def test_serialize_metrics_tuple(self):
        assert _serialize_metrics((0.85, 0.02))["oof_mean"] == 0.85

    def test_serialize_metrics_numpy(self):
        serialized = _serialize_metrics({"score": np.float64(0.85)})
        # NumPy scalars must come back as native Python floats.
        assert isinstance(serialized["score"], float)
193+
194+
195+
# --- BlueCastAIResult ---
196+
class TestBlueCastAIResult:
    """Tests for BlueCastAIResult repr, predict guard, and save/show helpers."""

    def test_repr(self):
        rendered = repr(BlueCastAIResult(metrics={"auc": 0.9}))
        assert "no pipeline" in rendered
        assert "auc" in rendered

    def test_repr_with_report(self):
        assert "report=yes" in repr(BlueCastAIResult(report_markdown="# Report"))

    def test_predict_no_pipeline_raises(self):
        with pytest.raises(RuntimeError, match="No trained pipeline"):
            BlueCastAIResult().predict(pd.DataFrame())

    def test_save_code(self):
        result = BlueCastAIResult(
            pipeline_code="pipeline.fit(df)",
            feature_engineering_code="df['new'] = 1",
        )
        # Reserve a temp path; the file itself is rewritten by save_code.
        with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode="w") as handle:
            target_path = handle.name
        try:
            result.save_code(target_path)
            with open(target_path) as handle:
                written = handle.read()
            assert "pipeline.fit" in written
            assert "df['new']" in written
            assert "Auto-generated" in written
        finally:
            os.unlink(target_path)

    def test_save_report(self):
        result = BlueCastAIResult(report_markdown="# My Report\nGreat results.")
        with tempfile.NamedTemporaryFile(suffix=".md", delete=False, mode="w") as handle:
            target_path = handle.name
        try:
            result.save_report(target_path)
            with open(target_path) as handle:
                assert "# My Report" in handle.read()
        finally:
            os.unlink(target_path)

    def test_save_log(self):
        log_entries = [
            AgentLogEntry(
                timestamp=1000.0, agent="A", event_type="task", content="hello"
            )
        ]
        result = BlueCastAIResult(structured_log=log_entries)
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as handle:
            target_path = handle.name
        try:
            result.save_log(target_path)
            with open(target_path) as handle:
                payload = json.load(handle)
            assert len(payload) == 1
            assert payload[0]["agent"] == "A"
        finally:
            os.unlink(target_path)

    def test_show_report_with_markdown(self, capsys):
        BlueCastAIResult(report_markdown="# Report Content").show_report()
        assert "# Report Content" in capsys.readouterr().out

    def test_show_report_without_markdown(self, capsys):
        # With no markdown, show_report falls back to a metrics summary.
        result = BlueCastAIResult(
            metrics={"roc_auc": 0.85},
            run_history=[{"success": True, "metrics": {"roc_auc": 0.85}}],
            class_problem="binary",
        )
        result.show_report()
        printed = capsys.readouterr().out
        assert "binary" in printed
        assert "0.85" in printed

0 commit comments

Comments
 (0)