Skip to content

Commit 5aaf259

Browse files
committed
fix: update test to use _build_record instead of _generate_sample for consistency and correctness in record creation during tests
1 parent 4a3734d commit 5aaf259

File tree

1 file changed

+51
-49
lines changed

1 file changed

+51
-49
lines changed

tests/test_generator_integration.py

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,30 @@
99
import pytest
1010

1111
from toolsgen.core.config import GenerationConfig, ModelConfig, RoleBasedModelConfig
12-
from toolsgen.core.generator import _generate_sample, generate_dataset
12+
from toolsgen.core.generator import generate_dataset
13+
from toolsgen.core.record_builder import _build_record
1314
from toolsgen.schema import AssistantToolCall, ToolFunction, ToolSpec
1415

1516

16-
@patch("toolsgen.core.generator.judge_tool_calls")
17-
@patch("toolsgen.core.generator.generate_tool_calls")
18-
@patch("toolsgen.core.generator.generate_problem")
17+
@patch("toolsgen.core.record_builder.judge_tool_calls")
18+
@patch("toolsgen.core.record_builder.generate_tool_calls")
19+
@patch("toolsgen.core.record_builder.generate_problem")
1920
def test_generate_sample_success(
2021
mock_problem: MagicMock, mock_tool_calls: MagicMock, mock_judge: MagicMock
2122
) -> None:
2223
"""Test successful sample generation."""
23-
# Mock problem generation
2424
mock_problem.return_value = "Send an email to user@example.com"
2525

26-
# Mock tool call generation
2726
mock_tool_call = AssistantToolCall(
2827
id="call_1",
2928
function={"name": "send_email", "arguments": '{"to": "user@example.com"}'},
3029
)
3130
mock_tool_calls.return_value = [mock_tool_call]
3231

33-
# Mock judge
3432
mock_judge_result = MagicMock()
3533
mock_judge_result.to_dict.return_value = {"score": 0.9, "verdict": "accept"}
3634
mock_judge.return_value = mock_judge_result
3735

38-
# Create mock clients
3936
problem_client = MagicMock()
4037
caller_client = MagicMock()
4138
judge_client = MagicMock()
@@ -46,7 +43,7 @@ def test_generate_sample_success(
4643

4744
role_config = RoleBasedModelConfig.from_single_config(ModelConfig(model="gpt-4"))
4845

49-
record = _generate_sample(
46+
record = _build_record(
5047
problem_client,
5148
caller_client,
5249
judge_client,
@@ -65,7 +62,7 @@ def test_generate_sample_success(
6562
assert record.judge["score"] == 0.9
6663

6764

68-
@patch("toolsgen.core.generator.generate_problem")
65+
@patch("toolsgen.core.record_builder.generate_problem")
6966
def test_generate_sample_problem_fails(mock_problem: MagicMock) -> None:
7067
"""Test sample generation when problem generation fails."""
7168
mock_problem.return_value = None
@@ -77,15 +74,21 @@ def test_generate_sample_problem_fails(mock_problem: MagicMock) -> None:
7774
tools = [ToolSpec(function=ToolFunction(name="test"))]
7875
role_config = RoleBasedModelConfig.from_single_config(ModelConfig(model="gpt-4"))
7976

80-
record = _generate_sample(
81-
problem_client, caller_client, judge_client, "rec_001", tools, role_config
77+
record = _build_record(
78+
problem_client,
79+
caller_client,
80+
judge_client,
81+
"rec_001",
82+
tools,
83+
role_config,
84+
"english",
8285
)
8386

8487
assert record is None
8588

8689

87-
@patch("toolsgen.core.generator.generate_tool_calls")
88-
@patch("toolsgen.core.generator.generate_problem")
90+
@patch("toolsgen.core.record_builder.generate_tool_calls")
91+
@patch("toolsgen.core.record_builder.generate_problem")
8992
def test_generate_sample_tool_calls_fail(
9093
mock_problem: MagicMock, mock_tool_calls: MagicMock
9194
) -> None:
@@ -100,16 +103,22 @@ def test_generate_sample_tool_calls_fail(
100103
tools = [ToolSpec(function=ToolFunction(name="test"))]
101104
role_config = RoleBasedModelConfig.from_single_config(ModelConfig(model="gpt-4"))
102105

103-
record = _generate_sample(
104-
problem_client, caller_client, judge_client, "rec_001", tools, role_config
106+
record = _build_record(
107+
problem_client,
108+
caller_client,
109+
judge_client,
110+
"rec_001",
111+
tools,
112+
role_config,
113+
"english",
105114
)
106115

107116
assert record is None
108117

109118

110-
@patch("toolsgen.core.generator.judge_tool_calls")
111-
@patch("toolsgen.core.generator.generate_tool_calls")
112-
@patch("toolsgen.core.generator.generate_problem")
119+
@patch("toolsgen.core.record_builder.judge_tool_calls")
120+
@patch("toolsgen.core.record_builder.generate_tool_calls")
121+
@patch("toolsgen.core.record_builder.generate_problem")
113122
def test_generate_sample_judge_fails(
114123
mock_problem: MagicMock, mock_tool_calls: MagicMock, mock_judge: MagicMock
115124
) -> None:
@@ -128,17 +137,22 @@ def test_generate_sample_judge_fails(
128137
tools = [ToolSpec(function=ToolFunction(name="test"))]
129138
role_config = RoleBasedModelConfig.from_single_config(ModelConfig(model="gpt-4"))
130139

131-
record = _generate_sample(
132-
problem_client, caller_client, judge_client, "rec_001", tools, role_config
140+
record = _build_record(
141+
problem_client,
142+
caller_client,
143+
judge_client,
144+
"rec_001",
145+
tools,
146+
role_config,
147+
"english",
133148
)
134149

135-
# Should still return record even if judge fails
136150
assert record is not None
137151
assert record.id == "rec_001"
138152

139153

140-
@patch("toolsgen.core.generator._generate_sample")
141-
@patch("toolsgen.core.generator.create_openai_client")
154+
@patch("toolsgen.core.sequential.RecordBuilder.generate_record")
155+
@patch("toolsgen.core.client.create_openai_client")
142156
def test_generate_dataset_basic(
143157
mock_create_client: MagicMock,
144158
mock_generate_sample: MagicMock,
@@ -148,7 +162,6 @@ def test_generate_dataset_basic(
148162
"""Test basic dataset generation."""
149163
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
150164

151-
# Create tools file
152165
tools_path = tmp_path / "tools.json"
153166
tools_data = [
154167
{
@@ -162,11 +175,9 @@ def test_generate_dataset_basic(
162175
]
163176
tools_path.write_text(json.dumps(tools_data), encoding="utf-8")
164177

165-
# Mock client creation
166178
mock_client = MagicMock()
167179
mock_create_client.return_value = mock_client
168180

169-
# Mock sample generation
170181
mock_record = MagicMock()
171182
mock_record.id = "rec_000000"
172183
mock_record.model_dump.return_value = {"id": "rec_000000"}
@@ -180,19 +191,17 @@ def test_generate_dataset_basic(
180191
output_dir, gen_config, model_config, tools_path=tools_path
181192
)
182193

183-
# Verify manifest
184194
assert manifest["num_requested"] == 3
185195
assert manifest["num_generated"] == 3
186196
assert manifest["strategy"] == "random"
187197
assert manifest["seed"] == 42
188198

189-
# Verify files created
190199
assert (output_dir / "manifest.json").exists()
191200
assert (output_dir / "train.jsonl").exists()
192201

193202

194-
@patch("toolsgen.core.generator._generate_sample")
195-
@patch("toolsgen.core.generator.create_openai_client")
203+
@patch("toolsgen.core.sequential.RecordBuilder.generate_record")
204+
@patch("toolsgen.core.client.create_openai_client")
196205
def test_generate_dataset_with_splits(
197206
mock_create_client: MagicMock,
198207
mock_generate_sample: MagicMock,
@@ -210,7 +219,6 @@ def test_generate_dataset_with_splits(
210219
mock_client = MagicMock()
211220
mock_create_client.return_value = mock_client
212221

213-
# Generate 10 records
214222
def create_mock_record(call_count: list[int] = [0]) -> MagicMock:
215223
record = MagicMock()
216224
record.id = f"rec_{call_count[0]:06d}"
@@ -239,13 +247,12 @@ def side_effect(*args: object, **kwargs: object) -> MagicMock:
239247
assert manifest["splits"]["train"] == 8
240248
assert manifest["splits"]["val"] == 2
241249

242-
# Verify split files
243250
assert (output_dir / "train.jsonl").exists()
244251
assert (output_dir / "val.jsonl").exists()
245252

246253

247-
@patch("toolsgen.core.generator._generate_sample")
248-
@patch("toolsgen.core.generator.create_openai_client")
254+
@patch("toolsgen.core.sequential.RecordBuilder.generate_record")
255+
@patch("toolsgen.core.client.create_openai_client")
249256
def test_generate_dataset_with_failures(
250257
mock_create_client: MagicMock,
251258
mock_generate_sample: MagicMock,
@@ -263,14 +270,13 @@ def test_generate_dataset_with_failures(
263270
mock_client = MagicMock()
264271
mock_create_client.return_value = mock_client
265272

266-
# Make first attempt fail, second succeed
267273
call_count = [0]
268274

269275
def mock_sample_gen(*args: object, **kwargs: object) -> MagicMock | None:
270276
call_count[0] += 1
271-
if call_count[0] % 2 == 1: # Odd calls fail
277+
if call_count[0] % 2 == 1:
272278
return None
273-
else: # Even calls succeed
279+
else:
274280
record = MagicMock()
275281
record.id = f"rec_{(call_count[0] // 2) - 1:06d}"
276282
record.model_dump.return_value = {"id": record.id}
@@ -288,10 +294,10 @@ def mock_sample_gen(*args: object, **kwargs: object) -> MagicMock | None:
288294

289295
assert manifest["num_requested"] == 3
290296
assert manifest["num_generated"] == 3
291-
assert manifest["num_failed"] >= 0 # Some attempts failed
297+
assert manifest["num_failed"] >= 0
292298

293299

294-
@patch("toolsgen.core.generator.create_openai_client")
300+
@patch("toolsgen.core.client.create_openai_client")
295301
def test_generate_dataset_role_based_config(
296302
mock_create_client: MagicMock, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
297303
) -> None:
@@ -315,8 +321,7 @@ def test_generate_dataset_role_based_config(
315321
gen_config = GenerationConfig(num_samples=1)
316322
output_dir = tmp_path / "output"
317323

318-
# This will fail during generation but we just want to test config handling
319-
with patch("toolsgen.core.generator._generate_sample") as mock_gen:
324+
with patch("toolsgen.core.sequential.RecordBuilder.generate_record") as mock_gen:
320325
mock_record = MagicMock()
321326
mock_record.id = "rec_000000"
322327
mock_record.model_dump.return_value = {"id": "rec_000000"}
@@ -326,14 +331,13 @@ def test_generate_dataset_role_based_config(
326331
output_dir, gen_config, role_config, tools_path=tools_path
327332
)
328333

329-
# Verify role-based models in manifest
330334
assert manifest["models"]["problem_generator"] == "gpt-4"
331335
assert manifest["models"]["tool_caller"] == "gpt-4o"
332336
assert manifest["models"]["judge"] == "gpt-4o-mini"
333337

334338

335-
@patch("toolsgen.core.generator._generate_sample")
336-
@patch("toolsgen.core.generator.create_openai_client")
339+
@patch("toolsgen.core.sequential.RecordBuilder.generate_record")
340+
@patch("toolsgen.core.client.create_openai_client")
337341
def test_generate_dataset_param_aware_strategy(
338342
mock_create_client: MagicMock,
339343
mock_generate_sample: MagicMock,
@@ -388,8 +392,8 @@ def test_generate_dataset_param_aware_strategy(
388392
assert manifest["num_generated"] == 2
389393

390394

391-
@patch("toolsgen.core.generator._generate_sample")
392-
@patch("toolsgen.core.generator.create_openai_client")
395+
@patch("toolsgen.core.sequential.RecordBuilder.generate_record")
396+
@patch("toolsgen.core.client.create_openai_client")
393397
def test_generate_dataset_with_tools_list(
394398
mock_create_client: MagicMock,
395399
mock_generate_sample: MagicMock,
@@ -399,7 +403,6 @@ def test_generate_dataset_with_tools_list(
399403
"""Test dataset generation with direct tools list instead of path."""
400404
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
401405

402-
# Create tools list directly
403406
tools = [
404407
ToolSpec(
405408
function=ToolFunction(
@@ -422,7 +425,6 @@ def test_generate_dataset_with_tools_list(
422425
model_config = ModelConfig(model="gpt-4")
423426
output_dir = tmp_path / "output"
424427

425-
# Call with tools list instead of tools_path
426428
manifest = generate_dataset(output_dir, gen_config, model_config, tools=tools)
427429

428430
assert manifest["num_requested"] == 2

0 commit comments

Comments
 (0)