|
| 1 | +"""Tests for the AI module (config, context, tools, result) - no LLM API keys needed.""" |
| 2 | + |
| 3 | +import json |
| 4 | +import os |
| 5 | +import tempfile |
| 6 | +import time |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | +import pytest |
| 11 | + |
| 12 | +from bluecast.ai.config import AIConfig |
| 13 | +from bluecast.ai.context import AgentLogEntry, SharedContext |
| 14 | +from bluecast.ai.result import BlueCastAIResult |
| 15 | +from bluecast.ai.tools import ( |
| 16 | + TOOL_DEFINITIONS, |
| 17 | + _serialize_metrics, |
| 18 | + tool_check_correlations, |
| 19 | + tool_check_leakage, |
| 20 | + tool_create_feature, |
| 21 | + tool_describe_data, |
| 22 | +) |
| 23 | + |
| 24 | + |
# --- AIConfig ---
class TestAIConfig:
    """Tests for AIConfig defaults, model-name resolution, and iteration limits."""

    def test_defaults(self):
        cfg = AIConfig(api_key="test")
        # Out-of-the-box settings the rest of the AI module relies on.
        assert cfg.provider == "gemini"
        assert cfg.temperature == 0.2
        assert cfg.max_rows_for_agents == 50_000
        assert cfg.checkpoint_dir is None

    def test_get_model_name_default(self):
        # Each provider resolves to a specific default model when none is set.
        expected_models = {
            "gemini": "gemini-2.5-flash",
            "openai": "gpt-4o",
            "anthropic": "claude-sonnet-4-20250514",
        }
        for provider, model in expected_models.items():
            cfg = AIConfig(api_key="t", provider=provider)
            assert cfg.get_model_name() == model

    def test_get_model_name_custom(self):
        # An explicitly supplied model name overrides the provider default.
        cfg = AIConfig(api_key="t", model="custom-model")
        assert cfg.get_model_name() == "custom-model"

    def test_get_max_iterations(self):
        # A positive max_iterations is used verbatim; otherwise mode decides.
        assert AIConfig(api_key="t", max_iterations=7).get_max_iterations() == 7
        fast_cfg = AIConfig(api_key="t", max_iterations=0, mode="fast")
        assert fast_cfg.get_max_iterations() == 1
        precise_cfg = AIConfig(api_key="t", max_iterations=0, mode="precise")
        assert precise_cfg.get_max_iterations() == 5
| 52 | + |
# --- AgentLogEntry ---
class TestAgentLogEntry:
    """Tests for the AgentLogEntry record type."""

    def test_creation(self):
        log_entry = AgentLogEntry(
            timestamp=time.time(),
            agent="TestAgent",
            event_type="task",
            content="hello",
            metadata={"key": "val"},
        )
        # Fields round-trip unchanged.
        assert log_entry.agent == "TestAgent"
        assert log_entry.event_type == "task"
        # The string form includes both the agent and the event type.
        rendered = str(log_entry)
        assert "TestAgent" in rendered
        assert "task" in rendered
| 68 | + |
# --- SharedContext ---
class TestSharedContext:
    """Tests for SharedContext logging, data summaries, and dataframe access."""

    def test_log(self):
        context = SharedContext()
        context.log("Agent1", "message", event_type="info", metadata={"x": 1})
        assert len(context.structured_log) == 1
        first_entry = context.structured_log[0]
        assert first_entry.agent == "Agent1"

    def test_agent_log_backward_compat(self):
        # Logging also feeds the legacy string-based agent_log.
        context = SharedContext()
        context.log("A", "msg")
        assert len(context.agent_log) == 1
        assert "[A]" in context.agent_log[0]

    def test_get_data_summary_empty(self):
        assert SharedContext().get_data_summary() == "No data loaded."

    def test_get_data_summary_classification(self):
        frame = pd.DataFrame({"a": [1, 2, 3], "target": [0, 1, 0]})
        context = SharedContext(df_train=frame, target_col="target")
        text = context.get_data_summary()
        assert "3 rows" in text
        assert "target" in text

    def test_get_data_summary_regression(self):
        # A continuous target should be described as such.
        frame = pd.DataFrame({"a": range(100), "target": np.random.randn(100)})
        context = SharedContext(df_train=frame, target_col="target")
        assert "continuous" in context.get_data_summary()

    def test_get_data_summary_sampled(self):
        frame = pd.DataFrame({"a": [1, 2], "target": [0, 1]})
        context = SharedContext(
            df_train=frame,
            target_col="target",
            df_sample=frame.head(1),
            was_sampled=True,
            original_shape=(1000, 5),
        )
        # Sampled data must be called out in the summary.
        assert "sample" in context.get_data_summary().lower()

    def test_get_working_df(self):
        # With a sample present, the working frame is the sample.
        full = pd.DataFrame({"a": [1, 2, 3]})
        subset = pd.DataFrame({"a": [1]})
        context = SharedContext(df_train=full, df_sample=subset)
        assert len(context.get_working_df()) == 1

    def test_get_working_df_no_sample(self):
        # Without a sample, the working frame falls back to the full data.
        full = pd.DataFrame({"a": [1, 2, 3]})
        context = SharedContext(df_train=full)
        assert len(context.get_working_df()) == 3

    def test_get_working_df_no_data_raises(self):
        with pytest.raises(ValueError):
            SharedContext().get_working_df()

    def test_get_full_df(self):
        # get_full_df ignores the sample and returns the complete frame.
        full = pd.DataFrame({"a": [1, 2, 3]})
        context = SharedContext(df_train=full, df_sample=full.head(1))
        assert len(context.get_full_df()) == 3
| 133 | + |
# --- Tools ---
class TestTools:
    """Tests for the agent tool functions and their metadata."""

    @pytest.fixture
    def sample_df(self):
        # Small, seeded frame with numeric, categorical, and binary-target columns.
        rng = np.random.default_rng(42)
        data = {
            "num1": rng.normal(0, 1, 100),
            "num2": rng.normal(5, 2, 100),
            "cat": rng.choice(["a", "b", "c"], 100),
            "target": rng.choice([0, 1], 100),
        }
        return pd.DataFrame(data)

    def test_describe_data(self, sample_df):
        description = tool_describe_data(sample_df, "target")
        assert "Shape" in description
        lowered = description.lower()
        assert "target" in lowered
        assert "binary" in lowered or "2" in description

    def test_check_correlations(self, sample_df):
        report = tool_check_correlations(sample_df, "target")
        lowered = report.lower()
        assert "correlation" in lowered or "correlations" in lowered

    def test_check_leakage(self, sample_df):
        report = tool_check_leakage(sample_df, "target")
        assert "leakage" in report.lower()

    def test_create_feature_success(self, sample_df):
        outcome = tool_create_feature(
            sample_df, "df['ratio'] = df['num1'] / (df['num2'] + 1)"
        )
        assert outcome["success"] is True
        assert "ratio" in outcome["new_columns"]

    def test_create_feature_failure(self, sample_df):
        # Referencing a missing column must fail gracefully with an error set.
        outcome = tool_create_feature(sample_df, "df['bad'] = df['nonexistent'] * 2")
        assert outcome["success"] is False
        assert outcome["error"] is not None

    def test_tool_definitions(self):
        # Every registered tool definition is keyed by its own name and
        # carries a meaningful description plus a typed parameter schema.
        assert len(TOOL_DEFINITIONS) >= 5
        for tool_name, definition in TOOL_DEFINITIONS.items():
            assert definition.name == tool_name
            assert len(definition.description) > 10
            assert "type" in definition.parameters

    def test_serialize_metrics_dict(self):
        serialized = _serialize_metrics(
            {"roc_auc": 0.85, "accuracy": 0.9, "report": "text"}
        )
        assert serialized["roc_auc"] == 0.85

    def test_serialize_metrics_tuple(self):
        # (mean, std) tuples are unpacked into named keys.
        serialized = _serialize_metrics((0.85, 0.02))
        assert serialized["oof_mean"] == 0.85

    def test_serialize_metrics_numpy(self):
        # NumPy scalars are converted to native Python floats.
        serialized = _serialize_metrics({"score": np.float64(0.85)})
        assert isinstance(serialized["score"], float)
| 193 | + |
# --- BlueCastAIResult ---
class TestBlueCastAIResult:
    """Tests for BlueCastAIResult repr, prediction guard, and file exports."""

    def test_repr(self):
        ai_result = BlueCastAIResult(metrics={"auc": 0.9})
        rendered = repr(ai_result)
        # With no pipeline attached, the repr says so and still shows metrics.
        assert "no pipeline" in rendered
        assert "auc" in rendered

    def test_repr_with_report(self):
        ai_result = BlueCastAIResult(report_markdown="# Report")
        assert "report=yes" in repr(ai_result)

    def test_predict_no_pipeline_raises(self):
        ai_result = BlueCastAIResult()
        # Predicting without a trained pipeline is a hard error.
        with pytest.raises(RuntimeError, match="No trained pipeline"):
            ai_result.predict(pd.DataFrame())

    def test_save_code(self):
        ai_result = BlueCastAIResult(
            pipeline_code="pipeline.fit(df)",
            feature_engineering_code="df['new'] = 1",
        )
        with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode="w") as tmp:
            tmp_path = tmp.name
        try:
            ai_result.save_code(tmp_path)
            with open(tmp_path) as fh:
                saved = fh.read()
            # Both code sections plus the generated header must be present.
            assert "pipeline.fit" in saved
            assert "df['new']" in saved
            assert "Auto-generated" in saved
        finally:
            os.unlink(tmp_path)

    def test_save_report(self):
        ai_result = BlueCastAIResult(report_markdown="# My Report\nGreat results.")
        with tempfile.NamedTemporaryFile(suffix=".md", delete=False, mode="w") as tmp:
            tmp_path = tmp.name
        try:
            ai_result.save_report(tmp_path)
            with open(tmp_path) as fh:
                saved = fh.read()
            assert "# My Report" in saved
        finally:
            os.unlink(tmp_path)

    def test_save_log(self):
        log_entries = [
            AgentLogEntry(
                timestamp=1000.0, agent="A", event_type="task", content="hello"
            )
        ]
        ai_result = BlueCastAIResult(structured_log=log_entries)
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as tmp:
            tmp_path = tmp.name
        try:
            ai_result.save_log(tmp_path)
            with open(tmp_path) as fh:
                saved = json.load(fh)
            # The structured log round-trips through JSON.
            assert len(saved) == 1
            assert saved[0]["agent"] == "A"
        finally:
            os.unlink(tmp_path)

    def test_show_report_with_markdown(self, capsys):
        ai_result = BlueCastAIResult(report_markdown="# Report Content")
        ai_result.show_report()
        assert "# Report Content" in capsys.readouterr().out

    def test_show_report_without_markdown(self, capsys):
        # Without markdown, a fallback summary is printed from metrics/history.
        ai_result = BlueCastAIResult(
            metrics={"roc_auc": 0.85},
            run_history=[{"success": True, "metrics": {"roc_auc": 0.85}}],
            class_problem="binary",
        )
        ai_result.show_report()
        printed = capsys.readouterr().out
        assert "binary" in printed
        assert "0.85" in printed
0 commit comments