Skip to content

Commit 0971812

Browse files
authored
Fix progress bar threading, status display, and PDF token estimation (#32)
* Fix redundant logger info * Fix progress bar threading * Fix status bar * Fix the token estimation for PDF files * Fix locking issue and better rich status refactor
1 parent 4d01564 commit 0971812

File tree

9 files changed

+1523
-1188
lines changed

9 files changed

+1523
-1188
lines changed

batchata/core/batch_run.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(self, config: BatchParams, jobs: List[Job]):
8383
# Threading primitives
8484
self._state_lock = threading.Lock()
8585
self._shutdown_event = threading.Event()
86+
self._progress_lock = threading.Lock()
87+
self._last_progress_update = 0.0
8688

8789
# Batch tracking for progress display
8890
self.batch_tracking: Dict[str, Dict] = {} # batch_id -> batch_info
@@ -231,9 +233,12 @@ def signal_handler(signum, frame):
231233

232234
# Call initial progress
233235
if self._progress_callback:
234-
stats = self.status()
235-
batch_data = dict(self.batch_tracking)
236-
self._progress_callback(stats, 0.0, batch_data)
236+
with self._progress_lock:
237+
with self._state_lock:
238+
stats = self.status()
239+
batch_data = dict(self.batch_tracking)
240+
self._progress_callback(stats, 0.0, batch_data)
241+
self._last_progress_update = time.time()
237242

238243
# Process all jobs synchronously
239244
self._process_all_jobs()
@@ -495,14 +500,22 @@ def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dic
495500
status, error_details = provider.get_batch_status(batch_id)
496501

497502
if self._progress_callback:
498-
with self._state_lock:
499-
stats = self.status()
500-
elapsed_time = (datetime.now() - self._start_time).total_seconds()
501-
batch_data = dict(self.batch_tracking)
502-
self._progress_callback(stats, elapsed_time, batch_data)
503+
# Rate limit progress updates and synchronize calls to prevent duplicate printing
504+
current_time = time.time()
505+
should_update = current_time - self._last_progress_update >= self._progress_interval
506+
507+
if should_update:
508+
with self._progress_lock:
509+
# Double-check timing inside the lock to avoid race condition
510+
if current_time - self._last_progress_update >= self._progress_interval:
511+
with self._state_lock:
512+
stats = self.status()
513+
elapsed_time = (datetime.now() - self._start_time).total_seconds()
514+
batch_data = dict(self.batch_tracking)
515+
self._progress_callback(stats, elapsed_time, batch_data)
516+
self._last_progress_update = current_time
503517

504518
elapsed_seconds = poll_count * provider_polling_interval
505-
logger.info(f"Batch {batch_id} status: {status} (polling for {elapsed_seconds:.1f}s)")
506519

507520
return status, error_details
508521

batchata/providers/openai/openai_provider.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -272,29 +272,36 @@ def estimate_cost(self, jobs: List[Job]) -> float:
272272

273273
for job in jobs:
274274
try:
275-
# Prepare messages to get actual input
276-
from .message_prepare import prepare_messages
277-
messages, response_format = prepare_messages(job)
278-
279-
# Build full text for token estimation
280-
full_text = ""
281-
for msg in messages:
282-
role = msg.get("role", "")
283-
content = msg.get("content", "")
284-
if isinstance(content, list):
285-
# Handle multipart content (images, etc.)
286-
for part in content:
287-
if part.get("type") == "text":
288-
full_text += f"{role}: {part.get('text', '')}\\n\\n"
289-
else:
290-
full_text += f"{role}: {content}\\n\\n"
291-
292-
# Add response format to token count if structured output
293-
if response_format:
294-
full_text += json.dumps(response_format)
295-
296-
# Estimate tokens
297-
input_tokens = token_count_simple(full_text)
275+
# Handle PDF files specially with accurate token estimation
276+
if job.file and job.file.suffix.lower() == '.pdf':
277+
from ...utils.pdf import estimate_pdf_tokens
278+
# OpenAI: 300-1,280 tokens/page, use 1000 as reasonable average
279+
input_tokens = estimate_pdf_tokens(job.file, job.prompt, tokens_per_page=1000)
280+
logger.debug(f"Job {job.id}: Estimated {input_tokens} tokens for PDF")
281+
else:
282+
# Prepare messages to get actual input
283+
from .message_prepare import prepare_messages
284+
messages, response_format = prepare_messages(job)
285+
286+
# Build full text for token estimation
287+
full_text = ""
288+
for msg in messages:
289+
role = msg.get("role", "")
290+
content = msg.get("content", "")
291+
if isinstance(content, list):
292+
# Handle multipart content (images, etc.)
293+
for part in content:
294+
if part.get("type") == "text":
295+
full_text += f"{role}: {part.get('text', '')}\\n\\n"
296+
else:
297+
full_text += f"{role}: {content}\\n\\n"
298+
299+
# Add response format to token count if structured output
300+
if response_format:
301+
full_text += json.dumps(response_format)
302+
303+
# Estimate tokens
304+
input_tokens = token_count_simple(full_text)
298305

299306
# Calculate costs using tokencost
300307
input_cost = float(calculate_cost_by_tokens(

batchata/utils/pdf.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -217,39 +217,46 @@ def estimate_pdf_tokens(path: str | Path, prompt: Optional[str] = None,
217217
"""
218218
Estimate token count for a PDF file.
219219
220-
This is a generic utility that can be used by any provider to estimate
221-
tokens for PDF processing.
220+
Provider-specific tokens per page estimates:
221+
- Anthropic: 1,500-3,000 tokens/page (default: 2000)
222+
- Gemini: ~258 tokens/page
223+
- OpenAI: 300-1,280 tokens/page (use: 1000)
222224
223225
Args:
224226
path: Path to the PDF file
225227
prompt: Optional prompt to include in token count
226-
pdf_token_multiplier: Coefficient to apply to extracted text tokens
227-
to account for PDF processing overhead (default: 1.5)
228-
tokens_per_page: Estimated tokens per page for image-based PDFs (default: 2000)
228+
pdf_token_multiplier: Deprecated, kept for compatibility
229+
tokens_per_page: Tokens per page estimate (default: 2000 for Anthropic)
229230
230231
Returns:
231232
Estimated token count
232233
"""
233234
from .llm import token_count_simple
234235

235-
page_count, is_textual, extracted_text = get_pdf_info(path)
236-
237-
if is_textual and extracted_text:
238-
# Count tokens from extracted text
239-
base_tokens = token_count_simple(extracted_text)
240-
if prompt:
241-
base_tokens += token_count_simple(prompt)
242-
243-
# Apply multiplier to account for PDF processing overhead
244-
input_tokens = int(base_tokens * pdf_token_multiplier)
245-
logger.debug(f"Textual PDF {path}: {page_count} pages, "
246-
f"base tokens: {base_tokens}, with {pdf_token_multiplier}x multiplier: {input_tokens}")
247-
else:
248-
# Estimate based on page count
249-
input_tokens = page_count * tokens_per_page
250-
if prompt:
251-
input_tokens += token_count_simple(prompt)
252-
logger.debug(f"Image-based PDF {path}: {page_count} pages, "
253-
f"estimated tokens: {input_tokens} ({tokens_per_page} per page)")
254-
255-
return input_tokens
236+
try:
237+
# Get page count
238+
reader = pypdf.PdfReader(str(path))
239+
page_count = len(reader.pages)
240+
241+
# Use provider-specific tokens per page estimate
242+
pdf_tokens = page_count * tokens_per_page
243+
244+
# Add prompt tokens
245+
prompt_tokens = token_count_simple(prompt) if prompt else 0
246+
247+
# Add minimal overhead for PDF processing
248+
PDF_TOKEN_OVERHEAD = 100 # tokens
249+
overhead_tokens = PDF_TOKEN_OVERHEAD
250+
251+
total_tokens = pdf_tokens + prompt_tokens + overhead_tokens
252+
253+
logger.debug(
254+
f"PDF {path}: {page_count} pages × {tokens_per_page} = {pdf_tokens} tokens, "
255+
f"prompt: {prompt_tokens}, total: {total_tokens}"
256+
)
257+
258+
return total_tokens
259+
260+
except Exception as e:
261+
logger.warning(f"Failed to estimate PDF tokens: {e}")
262+
return 0

batchata/utils/rich_progress.py

Lines changed: 102 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from rich.tree import Tree
1010
from rich.text import Text
1111

12+
# Constants
13+
PROGRESS_BAR_WIDTH = 25
14+
1215

1316
class RichBatchProgressDisplay:
1417
"""Rich-based progress display for batch runs."""
@@ -48,7 +51,7 @@ def start(self, stats: Dict, config: Dict):
4851
self._create_display(),
4952
console=self.console,
5053
refresh_per_second=4, # Reduced refresh rate to avoid flicker
51-
auto_refresh=True
54+
auto_refresh=False # Disable auto-refresh to prevent race conditions with manual updates
5255
)
5356
self.live.start()
5457

@@ -68,9 +71,11 @@ def update(self, stats: Dict, batch_data: Dict, elapsed_time: float):
6871
# Advance spinner
6972
self._spinner_index = (self._spinner_index + 1) % len(self._spinner_frames)
7073

71-
# Update live display
74+
# Update live display (synchronized to prevent race conditions)
7275
if self.live:
7376
self.live.update(self._create_display())
77+
# Force refresh since auto_refresh is disabled
78+
self.live.refresh()
7479

7580
def stop(self):
7681
"""Stop the live progress display."""
@@ -138,37 +143,17 @@ def _create_display(self) -> Tree:
138143
is_last = idx == num_batches - 1
139144
tree_symbol = "└─" if is_last else "├─"
140145

141-
# Format progress bar with better styling
142-
progress_pct = (completed / total) if total > 0 else 0
143-
filled_width = int(progress_pct * 25)
146+
# Extract job counts
147+
failed_count = batch_info.get('failed', 0)
148+
success_count = completed
149+
total_processed = success_count + failed_count
150+
progress_pct = (total_processed / total) if total > 0 else 0
144151

145-
if status == 'complete':
146-
bar = "[bold green]" + "━" * 25 + "[/bold green]"
147-
elif status == 'failed':
148-
bar = "[bold red]" + "━" * 25 + "[/bold red]"
149-
elif status == 'cancelled':
150-
bar = "[bold yellow]" + "━" * filled_width + "[/bold yellow]"
151-
if filled_width < 25:
152-
bar += "[dim yellow]" + "━" * (25 - filled_width) + "[/dim yellow]"
153-
elif status == 'running':
154-
bar = "[bold blue]" + "━" * filled_width + "[/bold blue]"
155-
if filled_width < 25:
156-
bar += "[blue]╸[/blue]" + "[dim white]" + "━" * (24 - filled_width) + "[/dim white]"
157-
else:
158-
bar = "[dim white]" + "━" * 25 + "[/dim white]"
152+
# Create progress bar based on status
153+
bar = self._create_progress_bar(status, success_count, failed_count, total, progress_pct)
159154

160-
# Format status with better colors and fixed width
161-
if status == 'complete':
162-
status_text = "[bold green]Ended[/bold green] "
163-
elif status == 'failed':
164-
status_text = "[bold red]Failed[/bold red] "
165-
elif status == 'cancelled':
166-
status_text = "[bold yellow]Cancelled[/bold yellow]"
167-
elif status == 'running':
168-
spinner = self._spinner_frames[self._spinner_index]
169-
status_text = f"[bold blue]{spinner} Running[/bold blue]"
170-
else:
171-
status_text = "[dim]Pending[/dim]"
155+
# Format status text
156+
status_text = self._format_status_text(status, failed_count)
172157

173158
# Calculate elapsed time
174159
start_time = batch_info.get('start_time')
@@ -196,7 +181,7 @@ def _create_display(self) -> Tree:
196181
else:
197182
time_str = "-:--:--"
198183

199-
# Format percentage
184+
# Format percentage based on total processed (successful + failed)
200185
percentage = int(progress_pct * 100)
201186

202187
# Get output filenames if completed
@@ -226,11 +211,11 @@ def _create_display(self) -> Tree:
226211
else:
227212
cost_text = f"${cost:>5.3f}"
228213

229-
# Create the batch line with proper spacing
214+
# Create the batch line
215+
display_stats = self._get_display_stats(status, success_count, failed_count, total)
230216
batch_line = (
231217
f"{provider} {batch_id:<18} {bar} "
232-
f"{completed:>2}/{total:<2} {percentage:>3}% "
233-
f"{status_text} "
218+
f"{display_stats['completed']:>2}/{total:<2} ({display_stats['percentage']}% done) {status_text:<15} "
234219
f"{cost_text} "
235220
f"{time_str:>8}"
236221
)
@@ -275,4 +260,85 @@ def _create_display(self) -> Tree:
275260
footer = " │ ".join(footer_parts)
276261
tree.add(f"\n[dim]{footer}[/dim]")
277262

278-
return tree
263+
return tree
264+
265+
def _create_progress_bar(self, status: str, success_count: int, failed_count: int, total: int, progress_pct: float) -> str:
266+
"""Create a progress bar showing success/failure proportions."""
267+
268+
if status == 'complete':
269+
return f"[bold green]{'━' * PROGRESS_BAR_WIDTH}[/bold green]"
270+
271+
if status == 'failed':
272+
return self._create_mixed_bar(success_count, failed_count, total, PROGRESS_BAR_WIDTH)
273+
274+
if status == 'cancelled':
275+
filled = int(progress_pct * PROGRESS_BAR_WIDTH)
276+
return f"[bold yellow]{'━' * filled}[/bold yellow][dim yellow]{'━' * (PROGRESS_BAR_WIDTH - filled)}[/dim yellow]"
277+
278+
if status == 'running':
279+
filled = int(progress_pct * PROGRESS_BAR_WIDTH)
280+
if filled < PROGRESS_BAR_WIDTH:
281+
return f"[bold blue]{'━' * filled}[/bold blue][blue]╸[/blue][dim white]{'━' * (PROGRESS_BAR_WIDTH - filled - 1)}[/dim white]"
282+
return f"[bold blue]{'━' * PROGRESS_BAR_WIDTH}[/bold blue]"
283+
284+
# Pending
285+
return f"[dim white]{'━' * PROGRESS_BAR_WIDTH}[/dim white]"
286+
287+
def _create_mixed_bar(self, success_count: int, failed_count: int, total: int, bar_width: int) -> str:
288+
"""Create a bar showing green (success) and red (failed) proportions."""
289+
if total == 0:
290+
return f"[dim white]{'━' * bar_width}[/dim white]"
291+
292+
# Use integer division to calculate base widths
293+
success_width = (success_count * bar_width) // total
294+
failed_width = (failed_count * bar_width) // total
295+
296+
# Distribute remainder to maintain exact bar_width
297+
remainder = bar_width - success_width - failed_width
298+
if remainder > 0:
299+
# Distribute remainder based on which segment has larger fractional part
300+
success_fraction = (success_count * bar_width) % total
301+
failed_fraction = (failed_count * bar_width) % total
302+
303+
if success_fraction >= failed_fraction:
304+
success_width += remainder
305+
else:
306+
failed_width += remainder
307+
308+
# Build the bar
309+
bar_parts = []
310+
if success_width > 0:
311+
bar_parts.append(f"[bold green]{'━' * success_width}[/bold green]")
312+
if failed_width > 0:
313+
bar_parts.append(f"[bold red]{'━' * failed_width}[/bold red]")
314+
315+
return "".join(bar_parts)
316+
317+
def _format_status_text(self, status: str, failed_count: int) -> str:
318+
"""Format the status text with appropriate colors and details."""
319+
if status == 'complete':
320+
return "[bold green]Complete[/bold green]"
321+
elif status == 'failed':
322+
if failed_count > 0:
323+
return f"[bold red]Failed ({failed_count})[/bold red]"
324+
return "[bold red]Failed[/bold red]"
325+
elif status == 'cancelled':
326+
return "[bold yellow]Cancelled[/bold yellow]"
327+
elif status == 'running':
328+
spinner = self._spinner_frames[self._spinner_index]
329+
return f"[bold blue]{spinner} Running[/bold blue]"
330+
else:
331+
return "[dim]Pending[/dim]"
332+
333+
def _get_display_stats(self, status: str, success_count: int, failed_count: int, total: int) -> dict:
334+
"""Get the display statistics (completed count and percentage)."""
335+
if status == 'failed' and failed_count > 0:
336+
# For failed batches, show success count to make it clear
337+
completed = success_count
338+
percentage = int((success_count / total) * 100) if total > 0 else 0
339+
else:
340+
# For other statuses, show total processed
341+
completed = success_count + failed_count
342+
percentage = int((completed / total) * 100) if total > 0 else 0
343+
344+
return {'completed': completed, 'percentage': percentage}

0 commit comments

Comments (0)