diff --git a/docs/howtos/customizations/customize_models.md b/docs/howtos/customizations/customize_models.md index 2ee6b7dfe6..f9e0978a0f 100644 --- a/docs/howtos/customizations/customize_models.md +++ b/docs/howtos/customizations/customize_models.md @@ -9,6 +9,9 @@ Ragas may use a LLM and or Embedding for evaluation and synthetic data generatio - If you are using Langchain, you can pass the Langchain LLM and Embeddings directly and Ragas will wrap it with `LangchainLLMWrapper` or `LangchainEmbeddingsWrapper` as needed. +!!! tip "Batch API Support" + OpenAI models (ChatOpenAI, AzureChatOpenAI) automatically support [Batch Evaluation](../metrics/batch_evaluation.md) for up to 50% cost savings on large-scale evaluations. The `LangchainLLMWrapper` automatically detects batch support and enables cost-optimized evaluation workflows. + ## Examples - [Azure OpenAI](#azure-openai) diff --git a/docs/howtos/customizations/index.md b/docs/howtos/customizations/index.md index 8223c9d82b..d0ecd62e7a 100644 --- a/docs/howtos/customizations/index.md +++ b/docs/howtos/customizations/index.md @@ -14,6 +14,7 @@ How to customize various aspects of Ragas to suit your needs. - [Adapt metrics to target language](./metrics/_metrics_language_adaptation.md) - [Trace evaluations with Observability tools](metrics/tracing.md) - [Train and align metric](./metrics/train_your_own_metric.md) +- [Batch evaluation for cost optimization](./metrics/batch_evaluation.md) 🆕 ## Testset Generation diff --git a/docs/howtos/customizations/metrics/_cost.md b/docs/howtos/customizations/metrics/_cost.md index 3cd5501a5e..3430d4c57f 100644 --- a/docs/howtos/customizations/metrics/_cost.md +++ b/docs/howtos/customizations/metrics/_cost.md @@ -1,6 +1,43 @@ # Understand Cost and Usage of Operations -When using LLMs for evaluation and test set generation, cost will be an important factor. Ragas provides you some tools to help you with that. +When using LLMs for evaluation and test set generation, cost will be an important factor. Ragas provides several tools to help you optimize costs, including **Batch API support** for up to 50% savings on large-scale evaluations. + +## Cost Optimization Strategies + +### 1. Use Batch API for Large Evaluations (50% Savings) + +For non-urgent evaluation workloads, Ragas supports OpenAI's Batch API which provides 50% cost savings: + +```python +from ragas.batch_evaluation import BatchEvaluator, estimate_batch_cost_savings +from ragas.metrics import Faithfulness +from langchain_openai import ChatOpenAI +from ragas.llms import LangchainLLMWrapper + +# Setup batch-capable LLM +llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini")) +faithfulness = Faithfulness(llm=llm) + +# Estimate cost savings +cost_info = estimate_batch_cost_savings( + sample_count=1000, + metrics=[faithfulness], + regular_cost_per_1k_tokens=0.15, # GPT-4o-mini cost + batch_discount=0.5 # 50% savings +) + +print(f"Regular cost: ${cost_info['regular_cost']}") +print(f"Batch cost: ${cost_info['batch_cost']}") +print(f"Savings: ${cost_info['savings']} ({cost_info['savings_percentage']}%)") + +# Run batch evaluation +evaluator = BatchEvaluator(metrics=[faithfulness]) +results = evaluator.evaluate(samples, wait_for_completion=True) +``` + +Learn more about [Batch Evaluation](batch_evaluation.md). + +### 2. 
Monitor Token Usage ## Understanding `TokenUsageParser` @@ -32,15 +69,12 @@ from ragas.cost import get_token_usage_for_openai get_token_usage_for_openai(llm_result) ``` - /opt/homebrew/Caskroom/miniforge/base/envs/ragas/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html - from .autonotebook import tqdm as notebook_tqdm - - - - - - TokenUsage(input_tokens=9, output_tokens=9, model='') +```py +/opt/homebrew/Caskroom/miniforge/base/envs/ragas/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html + from .autonotebook import tqdm as notebook_tqdm +TokenUsage(input_tokens=9, output_tokens=9, model='') +``` You can define your own or import parsers if they are defined. If you would like to suggest parser for LLM providers or contribute your own ones please check out this [issue](https://github.com/explodinggradients/ragas/issues/1151) 🙂. @@ -64,9 +98,9 @@ metric = AspectCriticWithReference( ) ``` - Repo card metadata block was not found. Setting CardData to empty. - - +```py +Repo card metadata block was not found. Setting CardData to empty. +``` ```python from ragas import evaluate @@ -80,38 +114,27 @@ results = evaluate( ) ``` - Evaluating: 100%|██████████| 5/5 [00:01<00:00, 2.81it/s] - - +```py +Evaluating: 100%|██████████| 5/5 [00:01<00:00, 2.81it/s] +``` ```python results.total_tokens() ``` - - - - TokenUsage(input_tokens=5463, output_tokens=355, model='') - +```py +TokenUsage(input_tokens=5463, output_tokens=355, model='') +``` You can compute the cost for each run by passing in the cost per token to `Result.total_cost()` function. In this case GPT-4o costs $5 for 1M input tokens and $15 for 1M output tokens. - ```python results.total_cost(cost_per_input_token=5 / 1e6, cost_per_output_token=15 / 1e6) ``` - - - - 0.03264 - - - - -```python - +```py +0.03264 ``` diff --git a/docs/howtos/customizations/metrics/batch_evaluation.md b/docs/howtos/customizations/metrics/batch_evaluation.md new file mode 100644 index 0000000000..d9de2d3c74 --- /dev/null +++ b/docs/howtos/customizations/metrics/batch_evaluation.md @@ -0,0 +1,305 @@ +# Batch Evaluation for Cost Optimization + +When running large-scale evaluations, cost can be a significant factor. Ragas now supports OpenAI's Batch API, which offers **up to 50% cost savings** compared to regular API calls, making it ideal for non-urgent evaluation workloads. + +## What is Batch Evaluation? + +OpenAI's Batch API allows you to submit multiple requests for asynchronous processing at half the cost of synchronous requests. Batch jobs are processed within 24 hours and have separate rate limits, making them perfect for large-scale evaluations where immediate results aren't required. 
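+
+Under the hood, Ragas serializes each request as one line of a JSONL file, uploads it to OpenAI, and then creates the batch job. A single request produced by Ragas looks roughly like this (a sketch with illustrative values):
+
+```python
+# Shape of one uploaded batch request (illustrative values)
+request_line = {
+    "custom_id": "ragas-batch-0",
+    "method": "POST",
+    "url": "/v1/chat/completions",
+    "body": {
+        "model": "gpt-4o-mini",
+        "messages": [{"role": "user", "content": "..."}],
+    },
+}
+```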
+
+### Key Benefits
+
+- **50% Cost Savings** on both input and output tokens
+- **Higher Rate Limits** that don't interfere with real-time usage
+- **Guaranteed Processing** within 24 hours (often much sooner)
+- **Large Scale Support** up to 50,000 requests per batch
+
+## Quick Start
+
+### Basic Batch Evaluation
+
+```python
+import os
+from ragas.batch_evaluation import BatchEvaluator, estimate_batch_cost_savings
+from ragas.dataset_schema import SingleTurnSample
+from ragas.metrics import Faithfulness
+from ragas.llms import LangchainLLMWrapper
+from langchain_openai import ChatOpenAI
+
+# Ensure you have your OpenAI API key set
+os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
+
+# Setup LLM with batch support (automatically detected)
+llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))
+faithfulness = Faithfulness(llm=llm)
+
+# Prepare your evaluation samples
+samples = [
+    SingleTurnSample(
+        user_input="What is the capital of France?",
+        response="The capital of France is Paris.",
+        retrieved_contexts=["Paris is the capital city of France."]
+    ),
+    # ... more samples
+]
+
+# Create batch evaluator
+evaluator = BatchEvaluator(metrics=[faithfulness])
+
+# Run batch evaluation (blocks until completion)
+results = evaluator.evaluate(samples, wait_for_completion=True)
+
+# Check results
+for result in results:
+    print(f"Metric: {result.metric_name}")
+    print(f"Job ID: {result.job_id}")
+    print(f"Success Rate: {result.success_rate:.2%}")
+    print(f"Sample Count: {result.sample_count}")
+```
+
+### Cost Estimation
+
+Before running batch evaluations, you can estimate your cost savings:
+
+```python
+from ragas.batch_evaluation import estimate_batch_cost_savings
+
+# Estimate costs for 1000 samples
+cost_info = estimate_batch_cost_savings(
+    sample_count=1000,
+    metrics=[faithfulness],
+    regular_cost_per_1k_tokens=0.15,  # example rate per 1K tokens; check your model's pricing
+    batch_discount=0.5  # 50% savings
+)
+
+print(f"Regular API Cost: ${cost_info['regular_cost']}")
+print(f"Batch API Cost: ${cost_info['batch_cost']}")
+print(f"Total Savings: ${cost_info['savings']} ({cost_info['savings_percentage']}%)")
+```
+
+### Asynchronous Batch Evaluation
+
+For non-blocking operations, use async evaluation:
+
+```python
+import asyncio
+
+async def run_batch_evaluation():
+    evaluator = BatchEvaluator(metrics=[faithfulness])
+
+    # Submit jobs without waiting
+    results = await evaluator.aevaluate(
+        samples=samples,
+        wait_for_completion=False  # Don't block
+    )
+
+    # Jobs are submitted, check back later
+    for result in results:
+        print(f"Submitted job {result.job_id} for {result.metric_name}")
+
+# Run async evaluation
+asyncio.run(run_batch_evaluation())
+```
+
+## Checking Batch Support
+
+Not all LLMs support batch evaluation. Here's how to check:
+
+```python
+# Check if metric supports batch evaluation
+if faithfulness.supports_batch_evaluation():
+    print(f"✅ {faithfulness.name} supports batch evaluation")
+else:
+    print(f"❌ {faithfulness.name} requires regular API")
+
+# Check LLM batch support
+if llm.supports_batch_api():
+    print("✅ LLM supports batch processing")
+else:
+    print("❌ LLM does not support batch processing")
+```
+
+## Supported Models
+
+Currently, batch evaluation is supported for:
+
+- OpenAI models (ChatOpenAI, AzureChatOpenAI)
+- All metrics that use these LLMs
+
+### Supported Metrics
+
+- ✅ Faithfulness (partial support)
+- 🔄 More metrics coming soon...
+
+Metrics that do not yet support batch evaluation fall back to regular API calls automatically.
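+
+If you mix batch-capable metrics with others, split the list using the `supports_batch_evaluation()` check shown above, because `BatchEvaluator` validates that every metric it receives supports batch evaluation. A minimal sketch (assuming `metrics` is your full list of configured metrics):
+
+```python
+# Partition metrics by batch capability before constructing the evaluator
+batch_metrics = [m for m in metrics if m.supports_batch_evaluation()]
+other_metrics = [m for m in metrics if not m.supports_batch_evaluation()]
+
+evaluator = BatchEvaluator(metrics=batch_metrics)
+# `other_metrics` can still be scored with the regular `evaluate()` workflow.
+```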
+ +## Configuration Options + +### BatchEvaluator Parameters + +```python +evaluator = BatchEvaluator( + metrics=metrics, + max_batch_size=1000, # Max samples per batch + poll_interval=300.0, # Status check interval (5 minutes) + timeout=86400.0 # Max wait time (24 hours) +) +``` + +### Custom Metadata + +Add metadata to track your batch jobs: + +```python +results = evaluator.evaluate( + samples=samples, + metadata={ + "experiment": "model_comparison", + "version": "v1.0", + "dataset": "production_qa" + } +) +``` + +## Best Practices + +### When to Use Batch Evaluation + +✅ **Ideal for:** +- Large-scale evaluations (100+ samples) +- Non-urgent evaluation workloads +- Cost optimization scenarios +- Regular evaluation pipelines + +❌ **Avoid for:** +- Real-time evaluation needs +- Interactive applications +- Small datasets (<50 samples) +- Time-sensitive workflows + +### Optimization Tips + +1. **Batch Size**: Use 1000-5000 samples per batch for optimal performance +2. **Model Selection**: Use cost-effective models like `gpt-4o-mini` +3. **Concurrent Processing**: Submit multiple metrics simultaneously +4. **Monitoring**: Set up logging for long-running jobs + +```python +import logging + +# Enable batch evaluation logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger('ragas.batch_evaluation') +``` + +## Error Handling + +```python +try: + results = evaluator.evaluate(samples) + + for result in results: + if result.errors: + print(f"❌ Errors in {result.metric_name}:") + for error in result.errors: + print(f" - {error}") + else: + print(f"✅ {result.metric_name}: {result.success_rate:.2%} success") + +except Exception as e: + print(f"Batch evaluation failed: {e}") +``` + +## Low-Level Batch API + +For advanced use cases, you can use the low-level batch API directly: + +```python +from ragas.llms.batch_api import create_batch_api, BatchRequest +from openai import OpenAI + +# Direct batch API usage +client = OpenAI() +batch_api = create_batch_api(client) + +# Create custom requests +requests = [ + BatchRequest( + custom_id="eval-1", + body={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Evaluate this response..."}] + } + ) +] + +# Submit batch job +batch_job = batch_api.create_batch(requests) +print(f"Batch job created: {batch_job.batch_id}") + +# Monitor progress +status = batch_job.get_status() +print(f"Status: {status.value}") + +# Retrieve results when complete +if status.value == "completed": + results = batch_job.get_results() + for result in results: + print(f"Response for {result.custom_id}: {result.response}") +``` + +## Troubleshooting + +### Common Issues + +**Issue**: "Batch API not supported for this LLM" +```python +# Solution: Use OpenAI-based LLM +from langchain_openai import ChatOpenAI +llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini")) +``` + +**Issue**: "Metric does not support batch evaluation" +```python +# Solution: Check metric support or wait for future updates +if not metric.supports_batch_evaluation(): + print(f"Metric {metric.name} will use regular API") +``` + +**Issue**: Timeout waiting for batch completion +```python +# Solution: Use non-blocking evaluation or increase timeout +results = evaluator.evaluate( + samples, + wait_for_completion=False # Don't wait +) +# Or increase timeout +evaluator = BatchEvaluator(timeout=172800.0) # 48 hours +``` + +## Migration from Regular Evaluation + +Converting existing evaluations to use batch processing is simple: + +### Before (Regular API) +```python +from ragas import 
evaluate +from ragas.metrics import Faithfulness + +results = evaluate( + dataset=eval_dataset, + metrics=[Faithfulness(llm=llm)] +) +``` + +### After (Batch API) +```python +from ragas.batch_evaluation import BatchEvaluator +from ragas.metrics import Faithfulness + +# Convert dataset to samples if needed +samples = [sample for sample in eval_dataset] + +evaluator = BatchEvaluator(metrics=[Faithfulness(llm=llm)]) +results = evaluator.evaluate(samples) +``` + +The batch API provides significant cost savings while maintaining the same evaluation quality, making it an excellent choice for large-scale evaluation workloads. \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index b1cbce9e09..c2442c6645 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -103,6 +103,7 @@ nav: - Write your own Metrics - (advanced): howtos/customizations/metrics/_write_your_own_metric_advanced.md - Train and Align Metrics: howtos/customizations/metrics/train_your_own_metric.md - Systematic Approach for Prompt Optimization: howtos/applications/prompt_optimization.md + - Batch Evaluation: howtos/customizations/metrics/batch_evaluation.md - Testset Generation: - Non-English Testset Generation: howtos/customizations/testgenerator/_language_adaptation.md - Persona Generation: howtos/customizations/testgenerator/_persona_generator.md diff --git a/pyproject.toml b/pyproject.toml index a30270491b..1ae6abd5b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,10 @@ ragas-examples = { workspace = true } addopts = "-n 0" asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] +python_classes = ["Test*"] +filterwarnings = [ + "ignore:cannot collect test class 'TestsetGenerator':pytest.PytestCollectionWarning" +] [dependency-groups] # Full dev dependencies with all features (used by make install) diff --git a/src/ragas/batch_evaluation.py b/src/ragas/batch_evaluation.py new file mode 100644 index 0000000000..5420aaaa74 --- /dev/null +++ b/src/ragas/batch_evaluation.py @@ -0,0 +1,437 @@ +""" +Batch evaluation utilities for cost-effective metric evaluation using OpenAI Batch API. + +This module provides high-level utilities for running Ragas metrics in batch mode, +offering significant cost savings (up to 50%) for large-scale evaluations. 
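+
+Typical usage (a minimal sketch; assumes ``faithfulness`` is a metric configured
+with a batch-capable OpenAI LLM and ``samples`` is a list of samples):
+
+    evaluator = BatchEvaluator(metrics=[faithfulness])
+    results = evaluator.evaluate(samples, wait_for_completion=True)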
+""" + +from __future__ import annotations + +import logging +import typing as t +from dataclasses import dataclass, field + +from ragas.dataset_schema import MultiTurnSample, SingleTurnSample +from ragas.metrics.base import MetricWithLLM + +if t.TYPE_CHECKING: + from ragas.llms.batch_api import BatchResponse + +logger = logging.getLogger(__name__) + + +@dataclass +class BatchEvaluationResult: + """Results from a batch evaluation job.""" + + metric_name: str + job_id: str + sample_count: int + responses: t.List[BatchResponse] + scores: t.Optional[t.List[t.Optional[float]]] = None + errors: t.List[str] = field(default_factory=list) + + @property + def success_rate(self) -> float: + """Calculate the success rate of the batch job.""" + if not self.responses: + return 0.0 + successful = sum(1 for resp in self.responses if resp.error is None) + return successful / len(self.responses) + + @property + def average_score(self) -> t.Optional[float]: + """Calculate average score if scores are available.""" + if not self.scores: + return None + valid_scores = [s for s in self.scores if s is not None] + return sum(valid_scores) / len(valid_scores) if valid_scores else None + + +class BatchEvaluator: + """High-level interface for batch evaluation using OpenAI Batch API.""" + + def __init__( + self, + metrics: t.List[MetricWithLLM], + max_batch_size: int = 1000, + poll_interval: float = 300.0, # 5 minutes + timeout: float = 86400.0, # 24 hours + ): + """ + Initialize batch evaluator. + + Args: + metrics: List of metrics to evaluate + max_batch_size: Maximum samples per batch job + poll_interval: Polling interval for batch job status (seconds) + timeout: Maximum time to wait for batch completion (seconds) + """ + self.metrics = metrics + self.max_batch_size = max_batch_size + self.poll_interval = poll_interval + self.timeout = timeout + + # Validate that all metrics support batch evaluation + for metric in metrics: + if not metric.supports_batch_evaluation(): + raise ValueError( + f"Metric '{metric.name}' does not support batch evaluation. " + "Ensure it uses an LLM that supports OpenAI Batch API." + ) + + def evaluate( + self, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + wait_for_completion: bool = True, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> t.List[BatchEvaluationResult]: + """ + Run batch evaluation on samples. 
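+
+        Submits one batch job per metric. When ``wait_for_completion`` is True,
+        the call blocks until each job completes, fails, or the configured
+        timeout is reached, and scores are extracted from the returned responses.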
+ + Args: + samples: Samples to evaluate + wait_for_completion: Whether to wait for jobs to complete + metadata: Optional metadata for batch jobs + + Returns: + List of batch evaluation results + """ + if len(samples) > self.max_batch_size: + raise ValueError( + f"Sample count {len(samples)} exceeds maximum batch size {self.max_batch_size}" + ) + + results = [] + jobs = [] + + # Create batch jobs for each metric + for metric in self.metrics: + logger.info(f"Creating batch job for metric: {metric.name}") + + job = metric.create_batch_evaluation_job( + samples=samples, batch_size=self.max_batch_size, metadata=metadata + ) + jobs.append((metric, job)) + + # Wait for completion if requested + if wait_for_completion: + for metric, job in jobs: + logger.info( + f"Waiting for batch job completion: {metric.name} (ID: {job.batch_id})" + ) + + status = job.wait_for_completion( + poll_interval=self.poll_interval, timeout=self.timeout + ) + + if status.value == "completed": + responses = job.get_results() + result = BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=responses, + ) + + # Process responses to extract scores + try: + result.scores = self._extract_scores(metric, responses) + except Exception as e: + result.errors.append(f"Score extraction failed: {str(e)}") + + results.append(result) + else: + # Job failed or was cancelled + result = BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=[], + errors=[f"Batch job failed with status: {status.value}"], + ) + results.append(result) + else: + # Return results with pending jobs + for metric, job in jobs: + result = BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=[], + ) + results.append(result) + + return results + + async def aevaluate( + self, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + wait_for_completion: bool = True, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> t.List[BatchEvaluationResult]: + """Async version of evaluate.""" + if len(samples) > self.max_batch_size: + raise ValueError( + f"Sample count {len(samples)} exceeds maximum batch size {self.max_batch_size}" + ) + + results = [] + jobs = [] + + # Create batch jobs for each metric + for metric in self.metrics: + logger.info(f"Creating batch job for metric: {metric.name}") + + job = await metric.acreate_batch_evaluation_job( + samples=samples, batch_size=self.max_batch_size, metadata=metadata + ) + jobs.append((metric, job)) + + # Wait for completion if requested + if wait_for_completion: + for metric, job in jobs: + logger.info( + f"Waiting for batch job completion: {metric.name} (ID: {job.batch_id})" + ) + + status = await job.await_completion( + poll_interval=self.poll_interval, timeout=self.timeout + ) + + if status.value == "completed": + responses = await job.aget_results() + result = BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=responses, + ) + + try: + result.scores = self._extract_scores(metric, responses) + except Exception as e: + result.errors.append(f"Score extraction failed: {str(e)}") + + results.append(result) + else: + result = BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=[], + errors=[f"Batch job failed with status: {status.value}"], + ) + results.append(result) + else: + for metric, job in jobs: + result = 
BatchEvaluationResult( + metric_name=metric.name, + job_id=job.batch_id, + sample_count=len(samples), + responses=[], + ) + results.append(result) + + return results + + def _extract_scores( + self, metric: MetricWithLLM, responses: t.List[BatchResponse] + ) -> t.List[t.Optional[float]]: + """ + Extract scores from batch responses. + + This method parses the batch responses and attempts to extract numerical scores + based on the metric's output format. It handles common patterns like JSON + responses with verdict fields or direct numerical outputs. + """ + scores = [] + + for response in responses: + score = None + + if response.error is not None: + logger.error( + f"Error in batch response {response.custom_id}: {response.error}" + ) + scores.append(None) + continue + + if response.response is None: + logger.warning(f"No response content for {response.custom_id}") + scores.append(None) + continue + + try: + # Extract content from OpenAI response format + content = self._extract_content_from_response(response.response) + if content is None: + scores.append(None) + continue + + # Parse as structured output first (JSON) + score = self._parse_structured_score(content, metric.name) + + # Parse raw text for score patterns + if score is None: + score = self._parse_text_score(content) + + scores.append(score) + + except Exception as e: + logger.error( + f"Failed to extract score from response {response.custom_id}: {e}" + ) + scores.append(None) + + return scores + + def _extract_content_from_response( + self, response: t.Dict[str, t.Any] + ) -> t.Optional[str]: + """Extract text content from OpenAI API response format.""" + try: + # Standard OpenAI chat completion response format + choices = response.get("choices", []) + if choices and len(choices) > 0: + message = choices[0].get("message", {}) + return message.get("content", "") + return None + except Exception as e: + logger.error(f"Failed to extract content from response: {e}") + return None + + def _parse_structured_score( + self, content: str, metric_name: str + ) -> t.Optional[float]: + """Parse structured JSON output to extract score.""" + try: + import json + import re + + # Clean the content to extract JSON + content = content.strip() + + # Look for JSON blocks + json_match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if json_match: + content = json_match.group(1) + elif content.startswith("{") and content.endswith("}"): + pass # Already clean JSON + else: + # Look for JSON object in text + json_match = re.search(r"\{[^{}]*\}", content) + if json_match: + content = json_match.group(0) + else: + return None + + parsed = json.loads(content) + + # Common patterns for different metrics + score_patterns = [ + "score", + "verdict", + "faithfulness_score", + "relevance_score", + "correctness_score", + "precision", + "recall", + "f1_score", + ] + + for pattern in score_patterns: + if pattern in parsed: + value = parsed[pattern] + if isinstance(value, (int, float)): + return float(value) + elif isinstance(value, str) and value.replace(".", "").isdigit(): + return float(value) + + # For faithfulness-like metrics, calculate score from statements + if "statements" in parsed and isinstance(parsed["statements"], list): + statements = parsed["statements"] + if statements: + verdicts = [] + for stmt in statements: + if isinstance(stmt, dict) and "verdict" in stmt: + verdict = stmt["verdict"] + if isinstance(verdict, (int, float)): + verdicts.append(verdict) + + if verdicts: + return sum(verdicts) / len(verdicts) + + return None + + 
except json.JSONDecodeError: + return None + except Exception as e: + logger.debug(f"Error parsing structured score: {e}") + return None + + def _parse_text_score(self, content: str) -> t.Optional[float]: + """Parse raw text content to find numerical scores.""" + import re + + # Look for common score patterns + patterns = [ + r"score[:\s]*([0-9]*\.?[0-9]+)", + r"verdict[:\s]*([0-9]*\.?[0-9]+)", + r"rating[:\s]*([0-9]*\.?[0-9]+)", + r"([0-9]*\.?[0-9]+)(?:\s*/\s*[0-9]+)?", # Simple number or fraction + ] + + for pattern in patterns: + matches = re.findall(pattern, content.lower()) + if matches: + try: + score = float(matches[0]) + # Validate score is in reasonable range (0-1 or 0-10) + if 0 <= score <= 1 or 0 <= score <= 10: + return score + except (ValueError, IndexError): + continue + + return None + + +def create_batch_evaluator( + metrics: t.List[MetricWithLLM], **kwargs: t.Any +) -> BatchEvaluator: + """Factory function to create a batch evaluator.""" + return BatchEvaluator(metrics=metrics, **kwargs) + + +def estimate_batch_cost_savings( + sample_count: int, + metrics: t.List[MetricWithLLM], + regular_cost_per_1k_tokens: float = 0.03, + batch_discount: float = 0.5, +) -> t.Dict[str, float]: + """ + Estimate cost savings from using batch API. + + Args: + sample_count: Number of samples to evaluate + metrics: List of metrics to run + regular_cost_per_1k_tokens: Regular API cost per 1K tokens + batch_discount: Batch API discount (0.5 = 50% savings) + + Returns: + Dictionary with cost estimates + """ + estimated_tokens_per_sample = 500 + total_tokens = sample_count * len(metrics) * estimated_tokens_per_sample + + regular_cost = (total_tokens / 1000) * regular_cost_per_1k_tokens + batch_cost = regular_cost * (1 - batch_discount) + savings = regular_cost - batch_cost + + return { + "regular_cost": round(regular_cost, 4), + "batch_cost": round(batch_cost, 4), + "savings": round(savings, 4), + "savings_percentage": round(batch_discount * 100, 1), + "estimated_tokens": total_tokens, + } diff --git a/src/ragas/llms/__init__.py b/src/ragas/llms/__init__.py index 62d50f2b07..cf823e4675 100644 --- a/src/ragas/llms/__init__.py +++ b/src/ragas/llms/__init__.py @@ -8,6 +8,15 @@ instructor_llm_factory, llm_factory, ) +from ragas.llms.batch_api import ( + BatchEndpoint, + BatchJob, + BatchRequest, + BatchResponse, + BatchStatus, + OpenAIBatchAPI, + create_batch_api, +) from ragas.llms.haystack_wrapper import HaystackLLMWrapper __all__ = [ @@ -20,4 +29,12 @@ "InstructorTypeVar", "instructor_llm_factory", "llm_factory", + # Batch API + "BatchEndpoint", + "BatchJob", + "BatchRequest", + "BatchResponse", + "BatchStatus", + "OpenAIBatchAPI", + "create_batch_api", ] diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py index 2b356b257f..c4d5325b7a 100644 --- a/src/ragas/llms/base.py +++ b/src/ragas/llms/base.py @@ -30,6 +30,8 @@ from langchain_core.prompt_values import PromptValue from llama_index.core.base.llms.base import BaseLLM + from ragas.llms.batch_api import BatchJob, BatchRequest, OpenAIBatchAPI + logger = logging.getLogger(__name__) @@ -59,6 +61,7 @@ class BaseRagasLLM(ABC): run_config: RunConfig = field(default_factory=RunConfig, repr=False) multiple_completion_supported: bool = field(default=False, repr=False) cache: t.Optional[CacheInterface] = field(default=None, repr=False) + batch_api_support: bool = field(default=False, repr=False) def __post_init__(self): # If a cache_backend is provided, wrap the implementation methods at construction time. 
@@ -127,6 +130,32 @@ async def generate( raise LLMDidNotFinishException() return result + def supports_batch_api(self) -> bool: + """Check if this LLM supports batch API operations.""" + return self.batch_api_support + + def create_batch_job( + self, + prompts: t.List[PromptValue], + n: int = 1, + temperature: t.Optional[float] = None, + stop: t.Optional[t.List[str]] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """Create a batch job for multiple prompts (sync version).""" + raise NotImplementedError("Batch API not implemented for this LLM") + + async def acreate_batch_job( + self, + prompts: t.List[PromptValue], + n: int = 1, + temperature: t.Optional[float] = None, + stop: t.Optional[t.List[str]] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """Create a batch job for multiple prompts (async version).""" + raise NotImplementedError("Batch API not implemented for this LLM") + class LangchainLLMWrapper(BaseRagasLLM): """ @@ -153,6 +182,12 @@ def __init__( # Certain LLMs (e.g., OpenAI o1 series) do not support temperature self.bypass_temperature = bypass_temperature + # Check if batch API is supported (OpenAI models only for now) + self.batch_api_support = isinstance( + langchain_llm, (ChatOpenAI, AzureChatOpenAI) + ) + self._batch_api: t.Optional["OpenAIBatchAPI"] = None + def is_finished(self, response: LLMResult) -> bool: """ Parse the response to check if the LLM finished by checking the finish_reason @@ -324,6 +359,121 @@ def set_run_config(self, run_config: RunConfig): self.langchain_llm.request_timeout = run_config.timeout self.run_config.exception_types = RateLimitError + def _get_batch_api(self) -> "OpenAIBatchAPI": + """Get or create OpenAI Batch API instance.""" + if not self.supports_batch_api(): + raise ValueError("Batch API not supported for this LLM") + + if self._batch_api is None: + # Get OpenAI client from the LangChain model + openai_client = None + if isinstance(self.langchain_llm, ChatOpenAI): + openai_client = self.langchain_llm.client + elif isinstance(self.langchain_llm, AzureChatOpenAI): + openai_client = self.langchain_llm.client + + if openai_client is None: + raise ValueError("Could not extract OpenAI client from LangChain model") + + from ragas.llms.batch_api import create_batch_api + + self._batch_api = create_batch_api(openai_client) + + return self._batch_api + + def create_batch_job( + self, + prompts: t.List[PromptValue], + n: int = 1, + temperature: t.Optional[float] = None, + stop: t.Optional[t.List[str]] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """Create a batch job for multiple prompts (sync version).""" + if not self.supports_batch_api(): + raise ValueError("Batch API not supported for this LLM") + + batch_api = self._get_batch_api() + + # Convert PromptValue to batch requests + batch_requests = self._create_batch_requests_from_prompts( + prompts, n, temperature, stop + ) + + return batch_api.create_batch(requests=batch_requests, metadata=metadata) + + async def acreate_batch_job( + self, + prompts: t.List[PromptValue], + n: int = 1, + temperature: t.Optional[float] = None, + stop: t.Optional[t.List[str]] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """Create a batch job for multiple prompts (async version).""" + if not self.supports_batch_api(): + raise ValueError("Batch API not supported for this LLM") + + batch_api = self._get_batch_api() + + # Convert PromptValue to batch requests + batch_requests = 
self._create_batch_requests_from_prompts( + prompts, n, temperature, stop + ) + + return await batch_api.acreate_batch(requests=batch_requests, metadata=metadata) + + def _create_batch_requests_from_prompts( + self, + prompts: t.List[PromptValue], + n: int = 1, + temperature: t.Optional[float] = None, + stop: t.Optional[t.List[str]] = None, + ) -> t.List["BatchRequest"]: + """Convert PromptValue objects to batch requests.""" + from ragas.llms.batch_api import BatchEndpoint, BatchRequest + + if temperature is None: + temperature = self.get_temperature(n) + + # Get model name + model_name = getattr(self.langchain_llm, "model_name", "gpt-3.5-turbo") + + requests = [] + for i, prompt in enumerate(prompts): + # Convert PromptValue to messages format + if hasattr(prompt, "to_messages"): + messages = [ + {"role": msg.type, "content": msg.content} + for msg in prompt.to_messages() + ] + else: + # Fallback for string prompts + messages = [{"role": "user", "content": str(prompt)}] + + body = { + "model": model_name, + "messages": messages, + "n": n, + "temperature": temperature, + } + + if stop is not None: + body["stop"] = stop + + # Remove unsupported parameters for certain models + if self.bypass_temperature: + body.pop("temperature", None) + + request = BatchRequest( + custom_id=f"ragas-batch-{i}", + url=BatchEndpoint.CHAT_COMPLETIONS.value, + body=body, + ) + requests.append(request) + + return requests + def __repr__(self) -> str: return f"{self.__class__.__name__}(langchain_llm={self.langchain_llm.__class__.__name__}(...))" diff --git a/src/ragas/llms/batch_api.py b/src/ragas/llms/batch_api.py new file mode 100644 index 0000000000..66e0087700 --- /dev/null +++ b/src/ragas/llms/batch_api.py @@ -0,0 +1,386 @@ +""" +OpenAI Batch API implementation for cost-effective evaluation. + +This module provides support for OpenAI's Batch API, enabling up to 50% cost savings +for non-urgent evaluation workloads. 
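+
+Typical usage (a minimal sketch; assumes ``client`` is an ``openai.OpenAI``
+instance and ``requests`` is a list of ``BatchRequest`` objects):
+
+    batch_api = create_batch_api(client)
+    job = batch_api.create_batch(requests)
+    status = job.wait_for_completion()
+    responses = job.get_results()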
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import tempfile +import time +import typing as t +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path + +if t.TYPE_CHECKING: + from openai import AsyncOpenAI, OpenAI + +logger = logging.getLogger(__name__) + + +class BatchStatus(Enum): + """Batch job status enumeration.""" + + VALIDATING = "validating" + FAILED = "failed" + IN_PROGRESS = "in_progress" + FINALIZING = "finalizing" + COMPLETED = "completed" + EXPIRED = "expired" + CANCELLING = "cancelling" + CANCELLED = "cancelled" + + +class BatchEndpoint(Enum): + """Supported batch API endpoints.""" + + CHAT_COMPLETIONS = "/v1/chat/completions" + COMPLETIONS = "/v1/completions" + EMBEDDINGS = "/v1/embeddings" + + +@dataclass +class BatchRequest: + """Represents a single batch request.""" + + custom_id: str + method: str = "POST" + url: str = BatchEndpoint.CHAT_COMPLETIONS.value + body: t.Dict[str, t.Any] = field(default_factory=dict) + + +@dataclass +class BatchResponse: + """Represents a single batch response.""" + + id: str + custom_id: str + response: t.Optional[t.Dict[str, t.Any]] = None + error: t.Optional[t.Dict[str, t.Any]] = None + + +class BatchJob: + """Represents an OpenAI batch job.""" + + def __init__( + self, + client: t.Union[OpenAI, AsyncOpenAI], + batch_id: str, + endpoint: str = BatchEndpoint.CHAT_COMPLETIONS.value, + completion_window: str = "24h", + ): + self.client = client + self.batch_id = batch_id + self.endpoint = endpoint + self.completion_window = completion_window + self._is_async = self._check_async_client(client) + + def _check_async_client(self, client: t.Any) -> bool: + """Check if the client is async.""" + return hasattr(client, "__aenter__") or "Async" in client.__class__.__name__ + + def get_status(self) -> BatchStatus: + """Get the current status of the batch job.""" + if self._is_async: + raise RuntimeError("Use aget_status() for async clients") + + batch = self.client.batches.retrieve(self.batch_id) # type: ignore[misc] + return BatchStatus(batch.status) # type: ignore[misc] + + async def aget_status(self) -> BatchStatus: + """Asynchronously get the current status of the batch job.""" + if not self._is_async: + raise RuntimeError("Use get_status() for sync clients") + + batch = await self.client.batches.retrieve(self.batch_id) # type: ignore[misc] + return BatchStatus(batch.status) # type: ignore[misc] + + def wait_for_completion( + self, poll_interval: float = 30.0, timeout: float = 86400.0 + ) -> BatchStatus: + """Wait for batch job completion with polling.""" + if self._is_async: + raise RuntimeError("Use await_completion() for async clients") + + start_time = time.time() + while True: + status = self.get_status() + + if status in [ + BatchStatus.COMPLETED, + BatchStatus.FAILED, + BatchStatus.EXPIRED, + BatchStatus.CANCELLED, + ]: + return status + + if time.time() - start_time > timeout: + raise TimeoutError( + f"Batch job {self.batch_id} did not complete within {timeout} seconds" + ) + + logger.info( + f"Batch job {self.batch_id} status: {status.value}. Waiting {poll_interval}s..." 
+ ) + time.sleep(poll_interval) + + async def await_completion( + self, poll_interval: float = 30.0, timeout: float = 86400.0 + ) -> BatchStatus: + """Asynchronously wait for batch job completion with polling.""" + if not self._is_async: + raise RuntimeError("Use wait_for_completion() for sync clients") + + start_time = time.time() + while True: + status = await self.aget_status() + + if status in [ + BatchStatus.COMPLETED, + BatchStatus.FAILED, + BatchStatus.EXPIRED, + BatchStatus.CANCELLED, + ]: + return status + + if time.time() - start_time > timeout: + raise TimeoutError( + f"Batch job {self.batch_id} did not complete within {timeout} seconds" + ) + + logger.info( + f"Batch job {self.batch_id} status: {status.value}. Waiting {poll_interval}s..." + ) + await asyncio.sleep(poll_interval) + + def get_results(self) -> t.List[BatchResponse]: + """Retrieve and parse batch job results.""" + if self._is_async: + raise RuntimeError("Use aget_results() for async clients") + + batch = self.client.batches.retrieve(self.batch_id) # type: ignore[misc] + + if batch.status != "completed": # type: ignore[misc] + raise ValueError( + f"Batch job {self.batch_id} is not completed. Status: {batch.status}" # type: ignore[misc] + ) + + if not batch.output_file_id: # type: ignore[misc] + raise ValueError(f"Batch job {self.batch_id} has no output file") + + # Download and parse results + result_content = self.client.files.content(batch.output_file_id).content # type: ignore[misc] + return self._parse_results(result_content) + + async def aget_results(self) -> t.List[BatchResponse]: + """Asynchronously retrieve and parse batch job results.""" + if not self._is_async: + raise RuntimeError("Use get_results() for sync clients") + + batch = await self.client.batches.retrieve(self.batch_id) # type: ignore[misc] + + if batch.status != "completed": # type: ignore[misc] + raise ValueError( + f"Batch job {self.batch_id} is not completed. 
Status: {batch.status}" # type: ignore[misc] + ) + + if not batch.output_file_id: # type: ignore[misc] + raise ValueError(f"Batch job {self.batch_id} has no output file") + + # Download and parse results + result_content = await self.client.files.content(batch.output_file_id) # type: ignore[misc] + return self._parse_results(result_content.content) + + def _parse_results(self, content: bytes) -> t.List[BatchResponse]: + """Parse batch results from JSONL content.""" + results = [] + for line in content.decode("utf-8").strip().split("\n"): + if line.strip(): + result_data = json.loads(line) + results.append( + BatchResponse( + id=result_data["id"], + custom_id=result_data["custom_id"], + response=result_data.get("response"), + error=result_data.get("error"), + ) + ) + return results + + +class OpenAIBatchAPI: + """OpenAI Batch API client wrapper.""" + + def __init__( + self, + client: t.Union[OpenAI, AsyncOpenAI], + max_batch_size: int = 50000, + max_file_size_mb: int = 100, + ): + self.client = client + self.max_batch_size = max_batch_size + self.max_file_size_mb = max_file_size_mb + self._is_async = self._check_async_client(client) + + def _check_async_client(self, client: t.Any) -> bool: + """Check if the client is async.""" + return hasattr(client, "__aenter__") or "Async" in client.__class__.__name__ + + def _create_jsonl_content(self, requests: t.List[BatchRequest]) -> str: + """Create JSONL content from batch requests.""" + lines = [] + for request in requests: + line = json.dumps( + { + "custom_id": request.custom_id, + "method": request.method, + "url": request.url, + "body": request.body, + } + ) + lines.append(line) + return "\n".join(lines) + + def _validate_requests(self, requests: t.List[BatchRequest]) -> None: + """Validate batch requests.""" + if len(requests) > self.max_batch_size: + raise ValueError( + f"Batch size {len(requests)} exceeds maximum {self.max_batch_size}" + ) + + # Check for duplicate custom_ids + custom_ids = [req.custom_id for req in requests] + if len(custom_ids) != len(set(custom_ids)): + raise ValueError("Duplicate custom_id values found in batch requests") + + # Estimate file size (rough approximation) + jsonl_content = self._create_jsonl_content(requests) + size_mb = len(jsonl_content.encode("utf-8")) / (1024 * 1024) + if size_mb > self.max_file_size_mb: + raise ValueError( + f"Batch file size {size_mb:.2f}MB exceeds maximum {self.max_file_size_mb}MB" + ) + + def create_batch( + self, + requests: t.List[BatchRequest], + endpoint: str = BatchEndpoint.CHAT_COMPLETIONS.value, + completion_window: str = "24h", + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> BatchJob: + """Create a new batch job.""" + if self._is_async: + raise RuntimeError("Use acreate_batch() for async clients") + + self._validate_requests(requests) + + # Create JSONL file + jsonl_content = self._create_jsonl_content(requests) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + f.write(jsonl_content) + temp_file_path = f.name + + try: + # Upload file + with open(temp_file_path, "rb") as f: + batch_file = self.client.files.create(file=f, purpose="batch") # type: ignore[misc] + + # Create batch job + batch_job = self.client.batches.create( # type: ignore[misc] + input_file_id=batch_file.id, # type: ignore[misc] + endpoint=endpoint, # type: ignore[arg-type] + completion_window=completion_window, # type: ignore[arg-type] + metadata=metadata or {}, + ) + + return BatchJob( + client=self.client, + batch_id=batch_job.id, # type: ignore[misc] + 
endpoint=endpoint, # type: ignore[arg-type] + completion_window=completion_window, # type: ignore[arg-type] + ) + + finally: + # Clean up temp file + Path(temp_file_path).unlink(missing_ok=True) + + async def acreate_batch( + self, + requests: t.List[BatchRequest], + endpoint: str = BatchEndpoint.CHAT_COMPLETIONS.value, + completion_window: str = "24h", + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> BatchJob: + """Asynchronously create a new batch job.""" + if not self._is_async: + raise RuntimeError("Use create_batch() for sync clients") + + self._validate_requests(requests) + + # Create JSONL file + jsonl_content = self._create_jsonl_content(requests) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + f.write(jsonl_content) + temp_file_path = f.name + + try: + # Upload file + with open(temp_file_path, "rb") as f: + batch_file = await self.client.files.create(file=f, purpose="batch") # type: ignore[misc] + + # Create batch job + batch_job = await self.client.batches.create( # type: ignore[misc] + input_file_id=batch_file.id, # type: ignore[misc] + endpoint=endpoint, # type: ignore[arg-type] + completion_window=completion_window, # type: ignore[arg-type] + metadata=metadata or {}, + ) + + return BatchJob( + client=self.client, + batch_id=batch_job.id, # type: ignore[misc] + endpoint=endpoint, # type: ignore[arg-type] + completion_window=completion_window, # type: ignore[arg-type] + ) + + finally: + # Clean up temp file + Path(temp_file_path).unlink(missing_ok=True) + + def create_chat_completion_requests( + self, + prompts: t.List[t.Dict[str, t.Any]], + model: str, + **kwargs: t.Any, + ) -> t.List[BatchRequest]: + """Create batch requests for chat completions.""" + requests = [] + for i, prompt_data in enumerate(prompts): + request = BatchRequest( + custom_id=f"request-{i}", + url=BatchEndpoint.CHAT_COMPLETIONS.value, + body={ + "model": model, + "messages": prompt_data.get("messages", []), + **kwargs, + }, + ) + # Allow custom_id override + if "custom_id" in prompt_data: + request.custom_id = prompt_data["custom_id"] + + requests.append(request) + + return requests + + +def create_batch_api(client: t.Union[OpenAI, AsyncOpenAI]) -> OpenAIBatchAPI: + """Factory function to create OpenAI Batch API instance.""" + return OpenAIBatchAPI(client) diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py index 32d474c815..36e2b2ae49 100644 --- a/src/ragas/metrics/_faithfulness.py +++ b/src/ragas/metrics/_faithfulness.py @@ -7,7 +7,7 @@ import numpy as np from pydantic import BaseModel, Field -from ragas.dataset_schema import SingleTurnSample +from ragas.dataset_schema import MultiTurnSample, SingleTurnSample from ragas.metrics.base import ( MetricOutputType, MetricType, @@ -18,6 +18,9 @@ if t.TYPE_CHECKING: from langchain_core.callbacks import Callbacks + from langchain_core.prompt_values import PromptValue + + from ragas.llms.batch_api import BatchJob, BatchResponse logger = logging.getLogger(__name__) @@ -193,6 +196,91 @@ def _compute_score(self, answers: NLIStatementOutput): return score + def _samples_to_prompts( + self, samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]] + ) -> t.List["PromptValue"]: + """ + Convert samples to PromptValue objects for batch processing. + + For Faithfulness metric, this implementation focuses on the statement generation step + as it's the most computationally expensive part. 
The NLI verification step would + need the statements generated from this batch job and would be handled in a + separate batch job or through regular evaluation. + + Note: This handles only the statement generation phase of faithfulness evaluation. + The complete faithfulness score requires a two-step process where NLI verification + follows statement generation. + """ + prompts = [] + for i, sample in enumerate(samples): + if not isinstance(sample, SingleTurnSample): + raise ValueError( + "Faithfulness metric only supports single-turn samples" + ) + + # Convert sample to dict for processing + sample_dict = sample.model_dump() + + # Create statement generation prompt using the actual PydanticPrompt + statement_prompt_input = StatementGeneratorInput( + question=sample_dict["user_input"], answer=sample_dict["response"] + ) + + # Convert the PydanticPrompt to PromptValue + # For now, we use a simplified conversion - in a real implementation + # this would use the actual prompt conversion method from PydanticPrompt + from langchain_core.prompts import ChatPromptTemplate + + # Create a simple prompt template as fallback + simple_prompt = ChatPromptTemplate.from_messages( + [ + ( + "user", + f"Question: {statement_prompt_input.question}\nAnswer: {statement_prompt_input.answer}\n\nGenerate individual factual statements from this answer.", + ) + ] + ) + prompt_value = simple_prompt.format_prompt() + + prompts.append(prompt_value) + + return prompts + + def create_complete_batch_evaluation_job( + self, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + batch_size: t.Optional[int] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "CompleteFaithfulnessBatchJob": + """ + Create a complete batch evaluation job that handles both statement generation + and NLI verification steps for faithfulness evaluation. + + This method returns a specialized batch job that orchestrates the two-step + faithfulness evaluation process in batch mode. + """ + if not self.supports_batch_evaluation(): + raise ValueError( + f"Metric '{self.name}' does not support batch evaluation. " + "Ensure the LLM supports batch API operations." + ) + + if batch_size is None: + batch_size = 1000 + + if len(samples) > batch_size: + raise ValueError( + f"Sample count {len(samples)} exceeds maximum batch size {batch_size}. " + "Consider splitting into smaller batches." + ) + + return CompleteFaithfulnessBatchJob( + faithfulness_metric=self, + samples=samples, + batch_size=batch_size, + metadata=metadata or {}, + ) + async def _single_turn_ascore( self, sample: SingleTurnSample, callbacks: Callbacks ) -> float: @@ -273,4 +361,344 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: return sum(scores) / len(scores) +class CompleteFaithfulnessBatchJob: + """ + A specialized batch job for complete faithfulness evaluation that handles + both statement generation and NLI verification in sequence. + """ + + def __init__( + self, + faithfulness_metric: Faithfulness, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + batch_size: int = 1000, + metadata: t.Optional[t.Dict[str, str]] = None, + ): + self.faithfulness_metric = faithfulness_metric + self.samples = samples + self.batch_size = batch_size + self.metadata = metadata or {} + self.statement_job: t.Optional["BatchJob"] = None + self.nli_job: t.Optional["BatchJob"] = None + + def execute(self) -> t.List[float]: + """ + Execute the complete faithfulness evaluation in batch mode. 
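+
+        Runs two sequential batch jobs: one for statement generation and one for
+        NLI verification against the retrieved contexts, then combines the
+        verdicts into a single faithfulness score per sample.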
+ + Returns: + List of faithfulness scores for each sample + """ + # Step 1: Generate statements using batch API + statement_prompts = self.faithfulness_metric._samples_to_prompts(self.samples) + + if self.faithfulness_metric.llm is None: + raise ValueError("Faithfulness metric has no LLM configured") + + self.statement_job = self.faithfulness_metric.llm.create_batch_job( + prompts=statement_prompts, + metadata={ + **self.metadata, + "step": "statement_generation", + "metric": "faithfulness", + }, + ) + + # Wait for statement generation to complete + status = self.statement_job.wait_for_completion() + if status.value != "completed": + raise RuntimeError(f"Statement generation batch job failed: {status.value}") + + # Get statement generation results + statement_responses = self.statement_job.get_results() + + # Parse statements from responses + all_statements = self._parse_statement_responses(statement_responses) + + # Step 2: Create NLI verification prompts + nli_prompts = self._create_nli_prompts(all_statements) + + if nli_prompts: # Only create NLI job if there are statements to verify + self.nli_job = self.faithfulness_metric.llm.create_batch_job( + prompts=nli_prompts, + metadata={ + **self.metadata, + "step": "nli_verification", + "metric": "faithfulness", + }, + ) + + # Wait for NLI verification to complete + nli_status = self.nli_job.wait_for_completion() + if nli_status.value != "completed": + raise RuntimeError( + f"NLI verification batch job failed: {nli_status.value}" + ) + + # Get NLI results and compute final scores + nli_responses = self.nli_job.get_results() + return self._compute_final_scores(all_statements, nli_responses) + else: + # No statements were generated, return NaN scores + return [np.nan] * len(self.samples) + + async def aexecute(self) -> t.List[float]: + """Async version of execute.""" + # Step 1: Generate statements using batch API + statement_prompts = self.faithfulness_metric._samples_to_prompts(self.samples) + + if self.faithfulness_metric.llm is None: + raise ValueError("Faithfulness metric has no LLM configured") + + self.statement_job = await self.faithfulness_metric.llm.acreate_batch_job( + prompts=statement_prompts, + metadata={ + **self.metadata, + "step": "statement_generation", + "metric": "faithfulness", + }, + ) + + # Wait for statement generation to complete + status = await self.statement_job.await_completion() + if status.value != "completed": + raise RuntimeError(f"Statement generation batch job failed: {status.value}") + + # Get statement generation results + statement_responses = await self.statement_job.aget_results() + + # Parse statements from responses + all_statements = self._parse_statement_responses(statement_responses) + + # Step 2: Create NLI verification prompts + nli_prompts = self._create_nli_prompts(all_statements) + + if nli_prompts: # Only create NLI job if there are statements to verify + self.nli_job = await self.faithfulness_metric.llm.acreate_batch_job( + prompts=nli_prompts, + metadata={ + **self.metadata, + "step": "nli_verification", + "metric": "faithfulness", + }, + ) + + # Wait for NLI verification to complete + nli_status = await self.nli_job.await_completion() + if nli_status.value != "completed": + raise RuntimeError( + f"NLI verification batch job failed: {nli_status.value}" + ) + + # Get NLI results and compute final scores + nli_responses = await self.nli_job.aget_results() + return self._compute_final_scores(all_statements, nli_responses) + else: + # No statements were generated, return NaN scores + return 
[np.nan] * len(self.samples) + + def _parse_statement_responses( + self, responses: t.List["BatchResponse"] + ) -> t.List[t.List[str]]: + """Parse statement generation responses to extract statements for each sample.""" + all_statements = [] + + for i, response in enumerate(responses): + statements = [] + + if response.error is None and response.response is not None: + try: + # Extract content from OpenAI response + content = self._extract_response_content(response.response) + + # Try to parse as StatementGeneratorOutput + statements = self._parse_statement_content(content) + + except Exception as e: + logger.warning(f"Failed to parse statements for sample {i}: {e}") + + all_statements.append(statements) + + return all_statements + + def _extract_response_content(self, response: t.Dict[str, t.Any]) -> str: + """Extract text content from OpenAI response.""" + try: + choices = response.get("choices", []) + if choices and len(choices) > 0: + message = choices[0].get("message", {}) + return message.get("content", "") + return "" + except Exception: + return "" + + def _parse_statement_content(self, content: str) -> t.List[str]: + """Parse statement generation content to extract individual statements.""" + try: + import json + import re + + # Clean content and try to parse as JSON + content = content.strip() + + # Look for JSON blocks + json_match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if json_match: + content = json_match.group(1) + elif content.startswith("{") and content.endswith("}"): + pass # Already clean JSON + else: + # Look for JSON object in text + json_match = re.search(r"\{[^{}]*\}", content) + if json_match: + content = json_match.group(0) + else: + return [] + + parsed = json.loads(content) + + # Extract statements from the parsed response + if "statements" in parsed and isinstance(parsed["statements"], list): + return [stmt for stmt in parsed["statements"] if isinstance(stmt, str)] + + return [] + + except (json.JSONDecodeError, Exception): + # Fallback: try to extract statements from text + lines = content.split("\n") + statements = [] + for line in lines: + line = line.strip() + if line and not line.startswith("#") and not line.startswith("*"): + # Remove common prefixes like "1.", "-", etc. + line = re.sub(r"^\d+\.\s*", "", line) + line = re.sub(r"^[-*]\s*", "", line) + if line: + statements.append(line) + return statements + + def _create_nli_prompts( + self, all_statements: t.List[t.List[str]] + ) -> t.List["PromptValue"]: + """Create NLI verification prompts for all statements.""" + prompts = [] + + for i, (sample, statements) in enumerate(zip(self.samples, all_statements)): + if not statements: # Skip if no statements were generated + continue + + if not isinstance(sample, SingleTurnSample): + continue # Skip non-single-turn samples + + # Get context for this sample + contexts = sample.retrieved_contexts or [] + contexts_str = "\n".join(contexts) + + # Create NLI prompt input + nli_input = NLIStatementInput(context=contexts_str, statements=statements) + + # Convert to prompt value - simplified implementation + from langchain_core.prompts import ChatPromptTemplate + + nli_prompt = ChatPromptTemplate.from_messages( + [ + ( + "user", + f"Context: {nli_input.context}\n\nStatements to verify:\n" + + "\n".join([f"- {stmt}" for stmt in nli_input.statements]) + + "\n\nFor each statement, determine if it can be directly inferred from the context. 
Return 1 for true, 0 for false.", + ) + ] + ) + prompt_value = nli_prompt.format_prompt() + + prompts.append(prompt_value) + + return prompts + + def _compute_final_scores( + self, + all_statements: t.List[t.List[str]], + nli_responses: t.List["BatchResponse"], + ) -> t.List[float]: + """Compute final faithfulness scores from NLI responses.""" + scores = [] + response_idx = 0 + + for statements in all_statements: + if not statements: + # No statements generated for this sample + scores.append(np.nan) + continue + + if response_idx >= len(nli_responses): + # No response available + scores.append(np.nan) + continue + + response = nli_responses[response_idx] + response_idx += 1 + + if response.error is not None or response.response is None: + scores.append(np.nan) + continue + + try: + # Parse NLI response + content = self._extract_response_content(response.response) + verdicts = self._parse_nli_content(content) + + if verdicts: + # Compute faithfulness score as average of verdicts + score = sum(verdicts) / len(verdicts) + scores.append(score) + else: + scores.append(np.nan) + + except Exception as e: + logger.warning(f"Failed to parse NLI response: {e}") + scores.append(np.nan) + + return scores + + def _parse_nli_content(self, content: str) -> t.List[float]: + """Parse NLI verification content to extract verdict scores.""" + try: + import json + import re + + # Clean content and try to parse as JSON + content = content.strip() + + # Look for JSON blocks + json_match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if json_match: + content = json_match.group(1) + elif content.startswith("{") and content.endswith("}"): + pass # Already clean JSON + else: + # Look for JSON object in text + json_match = re.search(r"\{[^{}]*\}", content) + if json_match: + content = json_match.group(0) + else: + return [] + + parsed = json.loads(content) + + # Extract verdicts from statements + if "statements" in parsed and isinstance(parsed["statements"], list): + verdicts = [] + for stmt in parsed["statements"]: + if isinstance(stmt, dict) and "verdict" in stmt: + verdict = stmt["verdict"] + if isinstance(verdict, (int, float)): + verdicts.append(float(verdict)) + return verdicts + + return [] + + except (json.JSONDecodeError, Exception): + return [] + + faithfulness = Faithfulness() diff --git a/src/ragas/metrics/base.py b/src/ragas/metrics/base.py index b3c4bc7bad..8a9e7d661f 100644 --- a/src/ragas/metrics/base.py +++ b/src/ragas/metrics/base.py @@ -22,10 +22,12 @@ if t.TYPE_CHECKING: from langchain_core.callbacks import Callbacks + from langchain_core.prompt_values import PromptValue from ragas.config import DemonstrationConfig, InstructionConfig from ragas.embeddings import BaseRagasEmbedding, BaseRagasEmbeddings from ragas.llms import BaseRagasLLM + from ragas.llms.batch_api import BatchJob logger = logging.getLogger(__name__) @@ -228,6 +230,111 @@ def init(self, run_config: RunConfig): ) self.llm.set_run_config(run_config) + def supports_batch_evaluation(self) -> bool: + """Check if this metric supports batch evaluation via OpenAI Batch API.""" + return self.llm is not None and self.llm.supports_batch_api() + + def create_batch_evaluation_job( + self, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + batch_size: t.Optional[int] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """ + Create a batch job for evaluating multiple samples. 
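+
+        Samples are converted to prompts via ``_samples_to_prompts`` and submitted
+        as a single batch job to the metric's LLM; use the returned ``BatchJob``
+        to poll status and retrieve results.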
+ + Args: + samples: List of samples to evaluate + batch_size: Maximum batch size (defaults to 1000) + metadata: Optional metadata for the batch job + + Returns: + BatchJob instance for monitoring and retrieving results + + Raises: + ValueError: If batch evaluation is not supported + """ + if not self.supports_batch_evaluation(): + raise ValueError( + f"Metric '{self.name}' does not support batch evaluation. " + "Ensure the LLM supports batch API operations." + ) + + if batch_size is None: + batch_size = 1000 + + # Split samples into batches if needed + if len(samples) > batch_size: + raise ValueError( + f"Sample count {len(samples)} exceeds maximum batch size {batch_size}. " + "Consider splitting into smaller batches." + ) + + # Convert samples to prompts for batch processing + prompts = self._samples_to_prompts(samples) + + if self.llm is None: + raise ValueError(f"Metric '{self.name}' has no LLM configured") + + return self.llm.create_batch_job( + prompts=prompts, + metadata={ + **(metadata or {}), + "metric_name": self.name, + "sample_count": str(len(samples)), + }, + ) + + async def acreate_batch_evaluation_job( + self, + samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]], + batch_size: t.Optional[int] = None, + metadata: t.Optional[t.Dict[str, str]] = None, + ) -> "BatchJob": + """Async version of create_batch_evaluation_job.""" + if not self.supports_batch_evaluation(): + raise ValueError( + f"Metric '{self.name}' does not support batch evaluation. " + "Ensure the LLM supports batch API operations." + ) + + if batch_size is None: + batch_size = 1000 + + if len(samples) > batch_size: + raise ValueError( + f"Sample count {len(samples)} exceeds maximum batch size {batch_size}. " + "Consider splitting into smaller batches." + ) + + prompts = self._samples_to_prompts(samples) + + if self.llm is None: + raise ValueError(f"Metric '{self.name}' has no LLM configured") + + return await self.llm.acreate_batch_job( + prompts=prompts, + metadata={ + **(metadata or {}), + "metric_name": self.name, + "sample_count": str(len(samples)), + }, + ) + + def _samples_to_prompts( + self, samples: t.List[t.Union[SingleTurnSample, MultiTurnSample]] + ) -> t.List["PromptValue"]: + """ + Convert samples to PromptValue objects for batch processing. + + This method should be overridden by specific metrics to customize + how samples are converted to prompts. + """ + raise NotImplementedError( + f"Metric '{self.name}' must implement _samples_to_prompts method " + "to support batch evaluation." + ) + def _optimize_instruction( self, instruction_config: InstructionConfig, diff --git a/tests/unit/llms/test_batch_api.py b/tests/unit/llms/test_batch_api.py new file mode 100644 index 0000000000..aa4d4fa72d --- /dev/null +++ b/tests/unit/llms/test_batch_api.py @@ -0,0 +1,283 @@ +""" +Unit tests for OpenAI Batch API functionality. 
+""" + +import json +from unittest.mock import AsyncMock, Mock, mock_open, patch + +import pytest + +from ragas.llms.batch_api import ( + BatchEndpoint, + BatchJob, + BatchRequest, + BatchResponse, + BatchStatus, + OpenAIBatchAPI, + create_batch_api, +) + + +class TestBatchRequest: + """Test BatchRequest dataclass.""" + + def test_batch_request_creation(self): + """Test creating a batch request.""" + request = BatchRequest( + custom_id="test-1", + method="POST", + url=BatchEndpoint.CHAT_COMPLETIONS.value, + body={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + + assert request.custom_id == "test-1" + assert request.method == "POST" + assert request.url == "/v1/chat/completions" + assert request.body["model"] == "gpt-4o-mini" + + +class TestBatchResponse: + """Test BatchResponse dataclass.""" + + def test_batch_response_creation(self): + """Test creating a batch response.""" + response = BatchResponse( + id="batch-123", + custom_id="test-1", + response={"choices": [{"message": {"content": "Hello!"}}]}, + error=None, + ) + + assert response.id == "batch-123" + assert response.custom_id == "test-1" + assert response.response is not None + assert response.error is None + + +class TestBatchJob: + """Test BatchJob class.""" + + def test_batch_job_sync_client(self): + """Test BatchJob with sync client.""" + mock_client = Mock() + mock_client.__class__.__name__ = "OpenAI" + + job = BatchJob( + client=mock_client, + batch_id="batch-123", + endpoint=BatchEndpoint.CHAT_COMPLETIONS.value, + ) + + assert job.client == mock_client + assert job.batch_id == "batch-123" + assert job.endpoint == "/v1/chat/completions" + assert not job._is_async + + def test_batch_job_async_client(self): + """Test BatchJob with async client.""" + mock_client = Mock() + mock_client.__class__.__name__ = "AsyncOpenAI" + + job = BatchJob(client=mock_client, batch_id="batch-123") + + assert job._is_async + + def test_get_status_sync(self): + """Test getting batch status synchronously.""" + mock_client = Mock() + mock_client.__class__.__name__ = "OpenAI" + mock_batch = Mock() + mock_batch.status = "completed" + mock_client.batches.retrieve.return_value = mock_batch + + job = BatchJob(client=mock_client, batch_id="batch-123") + status = job.get_status() + + assert status == BatchStatus.COMPLETED + mock_client.batches.retrieve.assert_called_once_with("batch-123") + + @pytest.mark.asyncio + async def test_get_status_async(self): + """Test getting batch status asynchronously.""" + mock_client = AsyncMock() + mock_client.__class__.__name__ = "AsyncOpenAI" + mock_batch = Mock() + mock_batch.status = "in_progress" + mock_client.batches.retrieve.return_value = mock_batch + + job = BatchJob(client=mock_client, batch_id="batch-123") + status = await job.aget_status() + + assert status == BatchStatus.IN_PROGRESS + mock_client.batches.retrieve.assert_called_once_with("batch-123") + + def test_parse_results(self): + """Test parsing batch results from JSONL content.""" + mock_client = Mock() + mock_client.__class__.__name__ = "OpenAI" + + job = BatchJob(client=mock_client, batch_id="batch-123") + + # Mock JSONL content + jsonl_content = """{"id": "batch-123", "custom_id": "req-1", "response": {"choices": [{"message": {"content": "Hello"}}]}} +{"id": "batch-123", "custom_id": "req-2", "error": {"message": "Rate limit exceeded"}}""" + + results = job._parse_results(jsonl_content.encode("utf-8")) + + assert len(results) == 2 + assert results[0].custom_id == "req-1" + assert results[0].response is not 
None + assert results[0].error is None + assert results[1].custom_id == "req-2" + assert results[1].response is None + assert results[1].error is not None + + +class TestOpenAIBatchAPI: + """Test OpenAIBatchAPI class.""" + + def test_batch_api_creation(self): + """Test creating batch API instance.""" + mock_client = Mock() + mock_client.__class__.__name__ = "OpenAI" + + api = OpenAIBatchAPI(client=mock_client) + + assert api.client == mock_client + assert api.max_batch_size == 50000 + assert api.max_file_size_mb == 100 + assert not api._is_async + + def test_create_jsonl_content(self): + """Test creating JSONL content from requests.""" + mock_client = Mock() + api = OpenAIBatchAPI(client=mock_client) + + requests = [ + BatchRequest( + custom_id="req-1", + body={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello"}], + }, + ), + BatchRequest( + custom_id="req-2", + body={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hi"}], + }, + ), + ] + + jsonl_content = api._create_jsonl_content(requests) + lines = jsonl_content.split("\n") + + assert len(lines) == 2 + req1 = json.loads(lines[0]) + assert req1["custom_id"] == "req-1" + assert req1["method"] == "POST" + assert req1["url"] == "/v1/chat/completions" + + def test_validate_requests_success(self): + """Test successful request validation.""" + mock_client = Mock() + api = OpenAIBatchAPI(client=mock_client, max_batch_size=10) + + requests = [ + BatchRequest(custom_id="req-1", body={"model": "gpt-4o-mini"}), + BatchRequest(custom_id="req-2", body={"model": "gpt-4o-mini"}), + ] + + # Should not raise any exception + api._validate_requests(requests) + + def test_validate_requests_too_many(self): + """Test validation failure for too many requests.""" + mock_client = Mock() + api = OpenAIBatchAPI(client=mock_client, max_batch_size=1) + + requests = [ + BatchRequest(custom_id="req-1", body={}), + BatchRequest(custom_id="req-2", body={}), + ] + + with pytest.raises(ValueError, match="Batch size 2 exceeds maximum 1"): + api._validate_requests(requests) + + def test_validate_requests_duplicate_ids(self): + """Test validation failure for duplicate custom IDs.""" + mock_client = Mock() + api = OpenAIBatchAPI(client=mock_client) + + requests = [ + BatchRequest(custom_id="req-1", body={}), + BatchRequest(custom_id="req-1", body={}), # Duplicate ID + ] + + with pytest.raises(ValueError, match="Duplicate custom_id values found"): + api._validate_requests(requests) + + def test_create_chat_completion_requests(self): + """Test creating chat completion requests.""" + mock_client = Mock() + api = OpenAIBatchAPI(client=mock_client) + + prompts = [ + {"messages": [{"role": "user", "content": "Hello"}]}, + {"messages": [{"role": "user", "content": "Hi"}], "custom_id": "custom-1"}, + ] + + requests = api.create_chat_completion_requests( + prompts=prompts, model="gpt-4o-mini", temperature=0.7 + ) + + assert len(requests) == 2 + assert requests[0].custom_id == "request-0" + assert requests[1].custom_id == "custom-1" + assert requests[0].body["model"] == "gpt-4o-mini" + assert requests[0].body["temperature"] == 0.7 + + @patch("tempfile.NamedTemporaryFile") + @patch("builtins.open", mock_open()) + @patch("pathlib.Path.unlink") + def test_create_batch_sync(self, mock_unlink, mock_temp_file): + """Test creating batch job synchronously.""" + # Setup mocks + mock_client = Mock() + mock_client.__class__.__name__ = "OpenAI" + + mock_temp_file_instance = Mock() + mock_temp_file_instance.name = "/tmp/test.jsonl" + 
mock_temp_file_instance.__enter__ = Mock(return_value=mock_temp_file_instance) + mock_temp_file_instance.__exit__ = Mock(return_value=None) + mock_temp_file.return_value = mock_temp_file_instance + + mock_file_obj = Mock() + mock_file_obj.id = "file-123" + mock_client.files.create.return_value = mock_file_obj + + mock_batch_job = Mock() + mock_batch_job.id = "batch-123" + mock_client.batches.create.return_value = mock_batch_job + + api = OpenAIBatchAPI(client=mock_client) + requests = [BatchRequest(custom_id="req-1", body={"model": "gpt-4o-mini"})] + + batch_job = api.create_batch(requests=requests) + + assert batch_job.batch_id == "batch-123" + mock_client.files.create.assert_called_once() + mock_client.batches.create.assert_called_once() + + +def test_create_batch_api(): + """Test factory function.""" + mock_client = Mock() + api = create_batch_api(mock_client) + + assert isinstance(api, OpenAIBatchAPI) + assert api.client == mock_client diff --git a/tests/unit/llms/test_batch_support.py b/tests/unit/llms/test_batch_support.py new file mode 100644 index 0000000000..e83a95bd4f --- /dev/null +++ b/tests/unit/llms/test_batch_support.py @@ -0,0 +1,197 @@ +""" +Unit tests for batch API support in LLM wrappers. +""" + +from typing import List, cast +from unittest.mock import Mock, patch + +import pytest +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.prompt_values import PromptValue +from langchain_openai import ChatOpenAI + +from ragas.llms.base import LangchainLLMWrapper +from ragas.llms.batch_api import BatchRequest + + +class MockPromptValue(PromptValue): + """Mock implementation of PromptValue for testing.""" + + def __init__(self, messages: List[BaseMessage]): + self._messages = messages + + def to_messages(self) -> List[BaseMessage]: + return self._messages + + def to_string(self) -> str: + return " ".join([str(msg.content) for msg in self._messages]) + + +class TestLangchainLLMWrapperBatchSupport: + """Test batch API support in LangchainLLMWrapper.""" + + def test_supports_batch_api_openai(self): + """Test batch API support detection for OpenAI models.""" + mock_openai_llm = Mock(spec=ChatOpenAI) + wrapper = LangchainLLMWrapper(mock_openai_llm) + + assert wrapper.supports_batch_api() + assert wrapper.batch_api_support + + def test_supports_batch_api_non_openai(self): + """Test batch API support detection for non-OpenAI models.""" + mock_other_llm = Mock() + mock_other_llm.__class__.__name__ = "ChatAnthropic" + wrapper = LangchainLLMWrapper(mock_other_llm) + + assert not wrapper.supports_batch_api() + assert not wrapper.batch_api_support + + @patch("ragas.llms.batch_api.create_batch_api") + def test_get_batch_api(self, mock_create_batch_api): + """Test getting batch API instance.""" + mock_openai_llm = Mock(spec=ChatOpenAI) + mock_openai_llm.client = Mock() + + mock_batch_api = Mock() + mock_create_batch_api.return_value = mock_batch_api + + wrapper = LangchainLLMWrapper(mock_openai_llm) + batch_api = wrapper._get_batch_api() + + assert batch_api == mock_batch_api + mock_create_batch_api.assert_called_once_with(mock_openai_llm.client) + + def test_get_batch_api_not_supported(self): + """Test getting batch API when not supported.""" + mock_other_llm = Mock() + wrapper = LangchainLLMWrapper(mock_other_llm) + + with pytest.raises(ValueError, match="Batch API not supported"): + wrapper._get_batch_api() + + def test_create_batch_requests_from_prompts(self): + """Test converting prompts to batch requests.""" + mock_openai_llm = 
Mock(spec=ChatOpenAI) + mock_openai_llm.model_name = "gpt-4o-mini" + + wrapper = LangchainLLMWrapper(mock_openai_llm) + + # Create proper prompt objects + prompt1 = MockPromptValue([HumanMessage(content="Hello, how are you?")]) + + prompt2 = MockPromptValue( + [ + SystemMessage(content="You are helpful"), + HumanMessage(content="What is AI?"), + ] + ) + + prompts = cast(List[PromptValue], [prompt1, prompt2]) + + requests = wrapper._create_batch_requests_from_prompts( + prompts=prompts, n=1, temperature=0.7, stop=["END"] + ) + + assert len(requests) == 2 + assert all(isinstance(req, BatchRequest) for req in requests) + + # Check first request + assert requests[0].custom_id == "ragas-batch-0" + assert requests[0].body["model"] == "gpt-4o-mini" + assert requests[0].body["temperature"] == 0.7 + assert requests[0].body["stop"] == ["END"] + assert len(requests[0].body["messages"]) == 1 + assert requests[0].body["messages"][0]["role"] == "human" + + # Check second request + assert requests[1].custom_id == "ragas-batch-1" + assert len(requests[1].body["messages"]) == 2 + assert requests[1].body["messages"][0]["role"] == "system" + assert requests[1].body["messages"][1]["role"] == "human" + + def test_create_batch_requests_string_fallback(self): + """Test batch request creation with string prompt fallback.""" + mock_openai_llm = Mock(spec=ChatOpenAI) + mock_openai_llm.model_name = "gpt-3.5-turbo" + + wrapper = LangchainLLMWrapper(mock_openai_llm) + + # Create prompt without to_messages method to trigger fallback + # Use a class that inherits from object and doesn't have to_messages + class StringOnlyPrompt: + def __str__(self): + return "What is Python?" + + prompt = StringOnlyPrompt() + + requests = wrapper._create_batch_requests_from_prompts( + cast(List[PromptValue], [prompt]) + ) + + assert len(requests) == 1 + assert requests[0].body["messages"][0]["role"] == "user" + assert requests[0].body["messages"][0]["content"] == "What is Python?" 
+ + def test_create_batch_requests_bypass_temperature(self): + """Test batch request creation with temperature bypass.""" + mock_openai_llm = Mock(spec=ChatOpenAI) + mock_openai_llm.model_name = "gpt-4o" + + wrapper = LangchainLLMWrapper(mock_openai_llm, bypass_temperature=True) + + prompt = MockPromptValue([HumanMessage(content="Hello")]) + + requests = wrapper._create_batch_requests_from_prompts( + cast(List[PromptValue], [prompt]), temperature=0.5 + ) + + # Temperature should be removed due to bypass_temperature=True + assert "temperature" not in requests[0].body + + @patch.object(LangchainLLMWrapper, "_get_batch_api") + @patch.object(LangchainLLMWrapper, "_create_batch_requests_from_prompts") + def test_create_batch_job(self, mock_create_requests, mock_get_batch_api): + """Test creating batch job.""" + mock_openai_llm = Mock(spec=ChatOpenAI) + wrapper = LangchainLLMWrapper(mock_openai_llm) + + mock_batch_api = Mock() + mock_batch_job = Mock() + mock_batch_api.create_batch.return_value = mock_batch_job + mock_get_batch_api.return_value = mock_batch_api + + mock_requests = [Mock(), Mock()] + mock_create_requests.return_value = mock_requests + + prompts = cast( + List[PromptValue], + [ + MockPromptValue([HumanMessage(content="Test prompt 1")]), + MockPromptValue([HumanMessage(content="Test prompt 2")]), + ], + ) + metadata = {"test": "value"} + + result = wrapper.create_batch_job( + prompts=prompts, n=2, temperature=0.8, stop=["STOP"], metadata=metadata + ) + + assert result == mock_batch_job + mock_get_batch_api.assert_called_once() + mock_create_requests.assert_called_once_with(prompts, 2, 0.8, ["STOP"]) + mock_batch_api.create_batch.assert_called_once_with( + requests=mock_requests, metadata=metadata + ) + + def test_create_batch_job_not_supported(self): + """Test creating batch job when not supported.""" + mock_other_llm = Mock() + wrapper = LangchainLLMWrapper(mock_other_llm) + + with pytest.raises(ValueError, match="Batch API not supported"): + wrapper.create_batch_job( + cast( + List[PromptValue], [MockPromptValue([HumanMessage(content="test")])] + ) + )
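
The tests above exercise the building blocks individually (request construction, wrapper-level `create_batch_job`, `BatchJob` status polling). For orientation, here is a minimal sketch of how the metric-level API added in `base.py` composes end to end. It assumes the `Faithfulness` changes earlier in this diff supply the required `_samples_to_prompts` override, and `get_results()` is an assumed accessor name — the patch shows `BatchJob._parse_results` but not the public retrieval method.

```python
import time

from langchain_openai import ChatOpenAI

from ragas.dataset_schema import SingleTurnSample
from ragas.llms import LangchainLLMWrapper
from ragas.llms.batch_api import BatchStatus
from ragas.metrics import Faithfulness

# Batch support is detected from the wrapped model (ChatOpenAI / AzureChatOpenAI).
llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))
metric = Faithfulness(llm=llm)
assert metric.supports_batch_evaluation()

samples = [
    SingleTurnSample(
        user_input="Where is the Eiffel Tower?",
        response="The Eiffel Tower is in Paris.",
        retrieved_contexts=["The Eiffel Tower is located in Paris, France."],
    )
]

# Submit a single batch job for all samples
# (raises ValueError if len(samples) exceeds batch_size).
job = metric.create_batch_evaluation_job(
    samples=samples,
    metadata={"run": "nightly-eval"},
)

# Poll until OpenAI finishes the batch (batches are processed within 24 hours);
# here we give up after roughly an hour of polling.
for _ in range(60):
    if job.get_status() == BatchStatus.COMPLETED:
        break
    time.sleep(60)

# Assumed accessor: result retrieval is not shown in this patch,
# so the method name below is illustrative only.
responses = job.get_results()
```

In practice the higher-level `BatchEvaluator` handles submission and polling; this sketch is only meant to show how `supports_batch_evaluation`, `create_batch_evaluation_job`, and `BatchStatus` fit together.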