Skip to content

Commit 4d48925

Browse files
authored
Move file content to be a reference in state file (#27)
* Move file content to be a reference in state file
* Clear pending when done in state file
* Fix tests locally
* Fix warning -> debug
1 parent a6c582d commit 4d48925

File tree

7 files changed

+78
-14
lines changed

7 files changed

+78
-14
lines changed

batchata/core/batch.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Batch builder."""
22

3+
import logging
34
import uuid
45
from datetime import datetime
56
from pathlib import Path
@@ -281,6 +282,14 @@ def add_job(
281282
if isinstance(file, str):
282283
file = Path(file)
283284

285+
# Warn about temporary file paths that may not persist
286+
if file:
287+
file_str = str(file)
288+
if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
289+
logger = logging.getLogger("batchata")
290+
logger.debug(f"File path appears to be in a temporary directory: {file}")
291+
logger.debug("This may cause issues when resuming from state if temp files are cleaned up")
292+
284293
# Create job
285294
job = Job(
286295
id=job_id,

batchata/core/batch_run.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -132,12 +132,27 @@ def _resume_from_state(self):
132132
self.pending_jobs = []
133133
for job_data in state.pending_jobs:
134134
job = Job.from_dict(job_data)
135-
self.pending_jobs.append(job)
135+
# Check if file exists (if job has a file)
136+
if job.file and not job.file.exists():
137+
logger.error(f"File not found for job {job.id}: {job.file}")
138+
logger.error("This may happen if files were in temporary directories that were cleaned up")
139+
self.failed_jobs[job.id] = f"File not found: {job.file}"
140+
else:
141+
self.pending_jobs.append(job)
136142

137-
# Restore completed results
138-
for result_data in state.completed_results:
139-
result = JobResult.from_dict(result_data)
140-
self.completed_results[result.job_id] = result
143+
# Restore completed results from file references
144+
for result_ref in state.completed_results:
145+
job_id = result_ref["job_id"]
146+
file_path = result_ref["file_path"]
147+
try:
148+
with open(file_path, 'r') as f:
149+
result_data = json.load(f)
150+
result = JobResult.from_dict(result_data)
151+
self.completed_results[job_id] = result
152+
except Exception as e:
153+
logger.error(f"Failed to load result for {job_id} from {file_path}: {e}")
154+
# Move to failed jobs if we can't load the result
155+
self.failed_jobs[job_id] = f"Failed to load result file: {e}"
141156

142157
# Restore failed jobs
143158
for job_data in state.failed_jobs:
@@ -162,7 +177,10 @@ def to_json(self) -> Dict:
162177
return {
163178
"created_at": datetime.now().isoformat(),
164179
"pending_jobs": [job.to_dict() for job in self.pending_jobs],
165-
"completed_results": [result.to_dict() for result in self.completed_results.values()],
180+
"completed_results": [
181+
{"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
182+
for job_id in self.completed_results.keys()
183+
],
166184
"failed_jobs": [
167185
{
168186
"id": job_id,
@@ -505,10 +523,15 @@ def _update_batch_results(self, batch_result: Dict):
505523
self.failed_jobs[result.job_id] = error_message
506524
self._save_result_to_file(result)
507525
logger.error(f"✗ Job {result.job_id} failed: {result.error}")
526+
527+
# Remove completed/failed job from pending
528+
self.pending_jobs = [job for job in self.pending_jobs if job.id != result.job_id]
508529

509530
# Update failed jobs
510531
for job_id, error in failed.items():
511532
self.failed_jobs[job_id] = error
533+
# Remove failed job from pending
534+
self.pending_jobs = [job for job in self.pending_jobs if job.id != job_id]
512535
logger.error(f"✗ Job {job_id} failed: {error}")
513536

514537
# Update batch tracking

batchata/utils/state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ class BatchState:
2222
created_at: When the batch was created
2323
pending_jobs: Jobs that haven't been submitted yet
2424
active_batches: Currently running batch IDs
25-
completed_results: Results from completed jobs
25+
completed_results: File references to completed job results
2626
failed_jobs: Jobs that failed with errors
2727
total_cost_usd: Total cost incurred so far
2828
config: Original batch configuration
@@ -31,7 +31,7 @@ class BatchState:
3131
created_at: str # ISO format datetime
3232
pending_jobs: List[Dict[str, Any]] # Serialized Job objects
3333
active_batches: List[str] # Provider batch IDs
34-
completed_results: List[Dict[str, Any]] # Serialized JobResult objects
34+
completed_results: List[Dict[str, str]] # File references: [{"job_id": "job_123", "file_path": "/path/to/result.json"}]
3535
failed_jobs: List[Dict[str, Any]] # Jobs with error info
3636
total_cost_usd: float
3737
config: Dict[str, Any] # Batch configuration

tests/core/test_batch_validation.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
"""Tests for batch validation with citations."""
22

33
import pytest
4+
import os
45
from pydantic import BaseModel
56
from typing import Optional
7+
from unittest.mock import patch
68

79
from batchata import Batch
810

@@ -29,6 +31,15 @@ class NestedInvoice(BaseModel):
2931
class TestBatchCitationValidation:
3032
"""Test early validation of citation compatibility."""
3133

34+
@pytest.fixture(autouse=True)
35+
def mock_api_keys(self):
36+
"""Provide mock API keys for provider initialization."""
37+
with patch.dict(os.environ, {
38+
'ANTHROPIC_API_KEY': 'test-anthropic-key',
39+
'OPENAI_API_KEY': 'test-openai-key'
40+
}):
41+
yield
42+
3243
def test_flat_model_with_citations_allowed(self):
3344
"""Test that flat models work with citations."""
3445
batch = Batch("./results").set_state(file="./state")

tests/core/test_pdf_validation.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
"""Tests for PDF validation in batch jobs."""
22

33
import pytest
4+
import os
45
import tempfile
56
from pathlib import Path
7+
from unittest.mock import patch
68

79
from batchata.core.batch import Batch
810
from batchata.utils.pdf import create_pdf
@@ -12,14 +14,23 @@
1214
class TestPdfValidation:
1315
"""Test PDF validation when citations are enabled."""
1416

17+
@pytest.fixture(autouse=True)
18+
def mock_api_keys(self):
19+
"""Provide mock API keys for provider initialization."""
20+
with patch.dict(os.environ, {
21+
'ANTHROPIC_API_KEY': 'test-anthropic-key',
22+
'OPENAI_API_KEY': 'test-openai-key'
23+
}):
24+
yield
25+
1526
def test_image_pdf_with_citations_fails(self):
1627
"""Test image PDF with citations should fail."""
1728
with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp:
1829
tmp.write(b"not a real pdf") # Score 0.0
1930
tmp.flush()
2031

2132
batch = Batch("/tmp/results").set_state(file="/tmp/state.json")
22-
batch.set_default_params(model="claude-3-5-sonnet-latest")
33+
batch.set_default_params(model="claude-3-5-sonnet-20241022")
2334

2435
with pytest.raises(ValidationError, match="appears to be image-only"):
2536
batch.add_job(file=tmp.name, prompt="Test", enable_citations=True)
@@ -33,7 +44,7 @@ def test_image_pdf_without_citations_works(self):
3344
tmp.flush()
3445

3546
batch = Batch("/tmp/results").set_state(file="/tmp/state.json")
36-
batch.set_default_params(model="claude-3-5-sonnet-latest")
47+
batch.set_default_params(model="claude-3-5-sonnet-20241022")
3748
batch.add_job(file=tmp.name, prompt="Test", enable_citations=False)
3849

3950
assert len(batch.jobs) == 1
@@ -49,7 +60,7 @@ def test_textual_pdf_with_citations_works(self):
4960
tmp.flush()
5061

5162
batch = Batch("/tmp/results").set_state(file="/tmp/state.json")
52-
batch.set_default_params(model="claude-3-5-sonnet-latest")
63+
batch.set_default_params(model="claude-3-5-sonnet-20241022")
5364
batch.add_job(file=tmp.name, prompt="Test", enable_citations=True)
5465

5566
assert len(batch.jobs) == 1
@@ -62,7 +73,7 @@ def test_non_pdf_bypasses_validation(self):
6273
tmp.flush()
6374

6475
batch = Batch("/tmp/results").set_state(file="/tmp/state.json")
65-
batch.set_default_params(model="claude-3-5-sonnet-latest")
76+
batch.set_default_params(model="claude-3-5-sonnet-20241022")
6677
batch.add_job(file=tmp.name, prompt="Test", enable_citations=True)
6778

6879
assert len(batch.jobs) == 1

tests/providers/test_provider_registry.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import pytest
10+
import os
1011
from unittest.mock import patch, MagicMock
1112

1213
from batchata.providers import get_provider
@@ -18,6 +19,15 @@
1819
class TestProviderRegistry:
1920
"""Test provider registry functionality."""
2021

22+
@pytest.fixture(autouse=True)
23+
def mock_api_keys(self):
24+
"""Provide mock API keys for provider initialization."""
25+
with patch.dict(os.environ, {
26+
'ANTHROPIC_API_KEY': 'test-anthropic-key',
27+
'OPENAI_API_KEY': 'test-openai-key'
28+
}):
29+
yield
30+
2131
def test_provider_lookup_by_model(self):
2232
"""Test looking up providers by model name."""
2333
# Create a mock provider and register it

tests/utils/test_state.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,7 @@ def test_save_and_load_state(self, temp_dir):
3535
],
3636
active_batches=["batch-123"],
3737
completed_results=[
38-
{"job_id": "job-0", "content": "A0", "cost": 0.01,
39-
"input_tokens": 10, "output_tokens": 20}
38+
{"job_id": "job-0", "file_path": "/tmp/job-0.json"}
4039
],
4140
failed_jobs=[],
4241
total_cost_usd=0.01,
@@ -57,6 +56,7 @@ def test_save_and_load_state(self, temp_dir):
5756
assert loaded.active_batches[0] == "batch-123"
5857
assert len(loaded.completed_results) == 1
5958
assert loaded.completed_results[0]["job_id"] == "job-0"
59+
assert loaded.completed_results[0]["file_path"] == "/tmp/job-0.json"
6060
assert loaded.total_cost_usd == 0.01
6161
assert loaded.config["max_parallel_batches"] == 10
6262

0 commit comments

Comments
 (0)