Skip to content

Commit c15f224

Browse files
authored
Add dry run flag for cost estimation (#31)
* Add dry run flag for cost estimation * Fix BatchRun import and update docs for dry run feature * Fix dry run func * Add pdf based text estimation for cost
1 parent 17f67df commit c15f224

File tree

9 files changed

+1569
-744
lines changed

9 files changed

+1569
-744
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ AI providers offer batch APIs that process requests asynchronously at 50% reduce
1616

1717
- Native batch processing (50% cost savings via provider APIs)
1818
- Set `max_cost_usd` limits for batch requests
19+
- **Dry run mode** for cost estimation and job planning
1920
- Time limit control with `.add_time_limit(seconds=, minutes=, hours=)`
2021
- State persistence in case of network interruption
2122
- Structured output `.json` format with Pydantic models
@@ -50,6 +51,9 @@ for file in files:
5051
run = batch.run()
5152

5253
results = run.results() # {"completed": [JobResult], "failed": [JobResult], "cancelled": [JobResult]}
54+
55+
# Or preview costs first with dry run
56+
run = batch.run(dry_run=True) # Shows cost estimates without executing
5357
```
5458

5559
## Complete Example

batchata/core/batch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import BaseModel
1010

1111
from .batch_params import BatchParams
12+
from .batch_run import BatchRun
1213
from .job import Job
1314
from ..providers import get_provider
1415
from ..types import Message
@@ -313,7 +314,7 @@ def add_job(
313314
self.jobs.append(job)
314315
return self
315316

316-
def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False) -> 'BatchRun':
317+
def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
317318
"""Execute the batch.
318319
319320
Creates a BatchRun instance and executes the jobs synchronously.
@@ -323,6 +324,7 @@ def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None,
323324
(stats_dict, elapsed_time_seconds, batch_data)
324325
progress_interval: Interval in seconds between progress updates (default: 1.0)
325326
print_status: Whether to show rich progress display (default: False)
327+
dry_run: If True, only show cost estimation without executing (default: False)
326328
327329
Returns:
328330
BatchRun instance with completed results
@@ -339,6 +341,10 @@ def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None,
339341
# Create and start the run
340342
run = BatchRun(self.config, self.jobs)
341343

344+
# Handle dry run mode
345+
if dry_run:
346+
return run.dry_run()
347+
342348
# Set progress callback - either rich display or custom callback
343349
if print_status:
344350
return self._run_with_rich_display(run, progress_interval)

batchata/core/batch_run.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,4 +826,81 @@ def _create_cancelled_results(self) -> List[JobResult]:
826826

827827
def shutdown(self):
828828
"""Shutdown (no-op for synchronous execution)."""
829-
pass
829+
pass
830+
831+
def dry_run(self) -> 'BatchRun':
    """Log a cost estimate and execution plan for the pending jobs; run nothing.

    Side effects visible in this method: when ``config.reuse_state`` is
    truthy, previously persisted state is loaded through
    ``state_manager.load_state``; ``self.pending_jobs`` is then rebuilt
    from the jobs whose id is not in ``self.completed_results``.
    Everything else is reported through ``logger``.

    Returns:
        This ``BatchRun`` instance, so the call can be chained. No job is
        submitted to a provider.
    """
    logger.info("=== DRY RUN MODE ===")
    logger.info("This will show cost estimates without executing jobs")

    # Pick up results from an earlier run so finished jobs are not re-estimated.
    if self.config.reuse_state:
        self.state_manager.load_state(self)

    already_done = self.completed_results
    self.pending_jobs = [
        candidate for candidate in self.jobs.values()
        if candidate.id not in already_done
    ]

    if not self.pending_jobs:
        logger.info("No pending jobs to analyze (all jobs already completed)")
        return self

    logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")

    # Jobs are estimated per provider, in chunks of ``items_per_batch``.
    provider_groups = self._group_jobs_by_provider()
    chunk_size = self.config.items_per_batch
    running_total = 0.0

    logger.info("\nJob breakdown:")
    for provider_name, group in provider_groups.items():
        # The first job's model selects the provider used for the whole group.
        provider = get_provider(group[0].model)
        logger.info(f"\n{provider_name} ({len(group)} jobs):")

        batch_number = 0
        for start in range(0, len(group), chunk_size):
            batch_number += 1
            chunk = group[start:start + chunk_size]
            chunk_cost = provider.estimate_cost(chunk)
            running_total += chunk_cost

            logger.info(f" Batch {batch_number}: {len(chunk)} jobs, estimated cost: ${chunk_cost:.4f}")
            for job in chunk:
                if job.file:
                    logger.info(f" - {job.id}: {job.file.name} (citations: {job.enable_citations})")
                else:
                    logger.info(f" - {job.id}: direct messages (citations: {job.enable_citations})")

    # Cost summary, compared against the configured limit when one is set.
    logger.info("\n=== COST SUMMARY ===")
    logger.info(f"Total estimated cost: ${running_total:.4f}")

    limit = self.config.cost_limit_usd
    if not limit:
        logger.info("No cost limit set")
    else:
        logger.info(f"Cost limit: ${limit:.2f}")
        if running_total > limit:
            overshoot = running_total - limit
            logger.warning(f"⚠️ Estimated cost exceeds limit by ${overshoot:.4f}")
        else:
            headroom = limit - running_total
            logger.info(f"✅ Within cost limit (${headroom:.4f} remaining)")

    # Execution plan: one batch per full or partial chunk in every group.
    logger.info("\n=== EXECUTION PLAN ===")
    batch_total = 0
    for group in provider_groups.values():
        # Ceiling division: -(-a // b) equals a // b plus one when a % b != 0.
        batch_total += -(-len(group) // self.config.items_per_batch)
    logger.info(f"Total batches to process: {batch_total}")
    logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
    logger.info(f"Items per batch: {self.config.items_per_batch}")
    logger.info(f"Results directory: {self.config.results_dir}")

    logger.info("\n=== DRY RUN COMPLETE ===")
    logger.info("To execute for real, call run() without dry_run=True")

    return self
906+

batchata/providers/anthropic/anthropic.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,30 @@ def estimate_cost(self, jobs: List[Job]) -> float:
238238
if system_prompt:
239239
full_text += system_prompt + "\n\n"
240240

241-
for msg in messages:
242-
role = msg.get("role", "")
243-
content = msg.get("content", "")
244-
full_text += f"{role}: {content}\n\n"
245-
246-
# Estimate tokens using Claude-specific estimator
247-
input_tokens = token_count_simple(full_text)
241+
# Handle PDF files specially
242+
if job.file and job.file.suffix.lower() == '.pdf':
243+
from ...utils.pdf import estimate_pdf_tokens
244+
input_tokens = estimate_pdf_tokens(job.file, job.prompt)
245+
logger.debug(f"Job {job.id}: Estimated {input_tokens} tokens for PDF")
246+
else:
247+
# Normal message handling
248+
for msg in messages:
249+
role = msg.get("role", "")
250+
content = msg.get("content", "")
251+
# Handle content that might be a list (for multimodal messages)
252+
if isinstance(content, list):
253+
for part in content:
254+
if isinstance(part, dict) and part.get("type") == "text":
255+
full_text += f"{role}: {part.get('text', '')}\n\n"
256+
else:
257+
full_text += f"{role}: {content}\n\n"
258+
259+
# Add prompt if it's a file-based job
260+
if job.prompt:
261+
full_text += f"\nUser prompt: {job.prompt}\n"
262+
263+
# Estimate tokens using Claude-specific estimator
264+
input_tokens = token_count_simple(full_text)
248265

249266
# Calculate costs using tokencost with actual Claude model
250267
input_cost = float(calculate_cost_by_tokens(
@@ -264,7 +281,7 @@ def estimate_cost(self, jobs: List[Job]) -> float:
264281
discount = model_config.batch_discount if model_config else 0.5
265282
job_cost = (input_cost + output_cost) * discount
266283

267-
logger.info(
284+
logger.debug(
268285
f"Job {job.id}: ~{input_tokens} input tokens, "
269286
f"{job.max_tokens} max output tokens, "
270287
f"cost: ${job_cost:.6f} (with {int(discount*100)}% batch discount)"
@@ -276,4 +293,5 @@ def estimate_cost(self, jobs: List[Job]) -> float:
276293
logger.warning(f"Failed to estimate cost for job {job.id}: {e}")
277294
continue
278295

279-
return total_cost
296+
return total_cost
297+

batchata/utils/pdf.py

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
"""
22
PDF Utilities Module
33
4-
Provides utility functions for creating test PDFs.
4+
Provides utility functions for creating test PDFs and extracting text.
55
"""
66

77
import re
88
from pathlib import Path
9-
from typing import List
9+
from typing import List, Tuple, Optional
1010

1111
import pypdf
12+
from ..utils import get_logger
13+
14+
logger = get_logger(__name__)
1215

1316

1417
def create_pdf(pages: List[str]) -> bytes:
@@ -136,4 +139,117 @@ def is_textual_pdf(
136139

137140
except Exception:
138141
# If PDF can't be read, assume it's not textual
139-
return 0.0
142+
return 0.0
143+
144+
145+
def extract_text_from_pdf(path: str | Path) -> str:
    """
    Collect the extractable text of every page in a PDF.

    Pages whose text is empty or whitespace-only are skipped, and a page
    that raises during extraction is logged at debug level and skipped.
    The remaining page texts are joined with a blank line between them.

    Args:
        path: Path to the PDF file

    Returns:
        str: The joined page texts, or an empty string when the PDF
        cannot be read at all (that failure is logged as a warning)
    """
    try:
        pdf = pypdf.PdfReader(str(path))
        page_texts = []

        for index, pdf_page in enumerate(pdf.pages):
            try:
                content = pdf_page.extract_text()
                has_content = bool(content.strip())
            except Exception as exc:
                # Best effort: one bad page must not discard the rest.
                logger.debug(f"Failed to extract text from page {index}: {exc}")
            else:
                if has_content:
                    page_texts.append(content)

        return "\n\n".join(page_texts)

    except Exception as exc:
        logger.warning(f"Failed to extract text from PDF {path}: {exc}")
        return ""
173+
174+
175+
def get_pdf_info(path: str | Path) -> Tuple[int, bool, Optional[str]]:
    """
    Summarise a PDF for cost estimation.

    Args:
        path: Path to the PDF file

    Returns:
        Tuple of (page_count, is_textual, extracted_text)
        - page_count: Number of pages reported by the PDF reader
        - is_textual: True when ``is_textual_pdf`` scores the file above
          0.5 and non-blank text could actually be extracted
        - extracted_text: The extracted text when is_textual is True,
          None otherwise

        On any failure the error is logged and (0, False, None) is returned.
    """
    try:
        page_count = len(pypdf.PdfReader(str(path)).pages)

        # NOTE(review): the score is presumably the fraction of pages that
        # contain text -- confirm against is_textual_pdf.
        textual = is_textual_pdf(path) > 0.5

        text = None
        if textual:
            candidate = extract_text_from_pdf(path)
            if candidate.strip():
                text = candidate
            else:
                # Scored as textual but nothing usable came out: treat the
                # file as image-based.
                textual = False

        logger.debug(f"PDF info for {path}: {page_count} pages, textual={textual}, "
                     f"text_length={len(text) if text else 0}")

        return page_count, textual, text

    except Exception as exc:
        logger.error(f"Failed to get PDF info for {path}: {exc}")
        return 0, False, None
212+
213+
214+
def estimate_pdf_tokens(path: str | Path, prompt: Optional[str] = None,
                        pdf_token_multiplier: float = 1.5,
                        tokens_per_page: int = 2000) -> int:
    """
    Estimate the input token count for sending a PDF to a model.

    Two strategies are used, chosen from the result of ``get_pdf_info``:

    - Textual PDF: the extracted text is token-counted, the prompt's
      tokens (if any) are added, and the sum is scaled by
      ``pdf_token_multiplier`` and truncated to an int.
    - Otherwise: ``page_count * tokens_per_page``, plus the prompt's
      tokens (if any), with no multiplier.

    Args:
        path: Path to the PDF file
        prompt: Optional prompt to include in token count
        pdf_token_multiplier: Coefficient applied to the combined text and
            prompt token count of a textual PDF (default: 1.5)
        tokens_per_page: Estimated tokens per page for PDFs without usable
            text (default: 2000)

    Returns:
        Estimated token count
    """
    from .llm import token_count_simple

    page_count, is_textual, extracted_text = get_pdf_info(path)
    prompt_tokens = token_count_simple(prompt) if prompt else 0

    if not (is_textual and extracted_text):
        # No usable text: fall back to a flat per-page allowance.
        estimate = page_count * tokens_per_page
        if prompt:
            estimate += prompt_tokens
        logger.debug(f"Image-based PDF {path}: {page_count} pages, "
                     f"estimated tokens: {estimate} ({tokens_per_page} per page)")
        return estimate

    text_and_prompt = token_count_simple(extracted_text)
    if prompt:
        text_and_prompt += prompt_tokens

    # NOTE(review): the multiplier also scales the prompt tokens here, unlike
    # the page-based branch -- confirm this is intended.
    estimate = int(text_and_prompt * pdf_token_multiplier)
    logger.debug(f"Textual PDF {path}: {page_count} pages, "
                 f"base tokens: {text_and_prompt}, with {pdf_token_multiplier}x multiplier: {estimate}")

    return estimate

0 commit comments

Comments
 (0)