diff --git a/cookbooks/zero_shot_evaluation/__main__.py b/cookbooks/zero_shot_evaluation/__main__.py new file mode 100644 index 00000000..b831abc1 --- /dev/null +++ b/cookbooks/zero_shot_evaluation/__main__.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +"""CLI entry point for zero-shot evaluation. + +Usage: + python -m cookbooks.zero_shot_evaluation --config config.yaml + python -m cookbooks.zero_shot_evaluation --config config.yaml --save + python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save +""" + +import asyncio +import json +from pathlib import Path +from typing import List, Optional + +import fire +from loguru import logger + +from cookbooks.zero_shot_evaluation.schema import GeneratedQuery, load_config +from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline + + +def _load_queries_from_file(queries_file: str) -> List[GeneratedQuery]: + """Load pre-generated queries from JSON file.""" + with open(queries_file, "r", encoding="utf-8") as f: + data = json.load(f) + queries = [GeneratedQuery(**item) for item in data] + logger.info(f"Loaded {len(queries)} queries from {queries_file}") + return queries + + +async def _run_evaluation( + config_path: str, + output_dir: Optional[str] = None, + queries_file: Optional[str] = None, + save: bool = False, + resume: bool = True, +) -> None: + """Run evaluation pipeline. + + Args: + config_path: Path to configuration file + output_dir: Output directory (overrides config) + queries_file: Path to pre-generated queries JSON file (skip generation) + save: Whether to save results to file + resume: Whether to resume from checkpoint + """ + config = load_config(config_path) + + if output_dir: + config.output.output_dir = output_dir + + # Load pre-generated queries if provided + queries = None + if queries_file: + queries = _load_queries_from_file(queries_file) + + pipeline = ZeroShotPipeline(config=config, resume=resume) + result = await pipeline.evaluate(queries=queries) + + if save: + pipeline.save_results(result, output_dir) + + +def main( + config: str, + output_dir: Optional[str] = None, + queries_file: Optional[str] = None, + save: bool = False, + fresh: bool = False, +) -> None: + """Zero-shot evaluation CLI with checkpoint support. 
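# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Expected shape of the --queries_file JSON consumed by _load_queries_from_file()
# above: a list of objects matching the GeneratedQuery schema ("query" required,
# "category"/"difficulty" optional). File name and contents are placeholders.
import json

queries = [
    {"query": "How do refunds work for annual plans?", "category": "billing", "difficulty": "easy"},
    {"query": "What happens if a payment fails twice in a row?", "difficulty": "medium"},
]

with open("queries.json", "w", encoding="utf-8") as f:
    json.dump(queries, f, indent=2, ensure_ascii=False)
# --------------------------------------------------------------------------------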
+ + Args: + config: Path to YAML configuration file + output_dir: Output directory for results + queries_file: Path to pre-generated queries JSON (skip query generation) + save: Whether to save results to file + fresh: Start fresh, ignore any existing checkpoint + + Examples: + # Normal run (auto-resumes from checkpoint) + python -m cookbooks.zero_shot_evaluation --config config.yaml --save + + # Use pre-generated queries + python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save + + # Start fresh, ignore checkpoint + python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save + """ + config_path = Path(config) + if not config_path.exists(): + logger.error(f"Config file not found: {config}") + return + + if queries_file: + queries_path = Path(queries_file) + if not queries_path.exists(): + logger.error(f"Queries file not found: {queries_file}") + return + + logger.info(f"Starting zero-shot evaluation with config: {config}") + if queries_file: + logger.info(f"Using pre-generated queries from: {queries_file}") + if fresh: + logger.info("Starting fresh (ignoring checkpoint)") + else: + logger.info("Resume mode enabled (will continue from checkpoint if exists)") + + asyncio.run(_run_evaluation(str(config_path), output_dir, queries_file, save, resume=not fresh)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/cookbooks/zero_shot_evaluation/checkpoint.py b/cookbooks/zero_shot_evaluation/checkpoint.py new file mode 100644 index 00000000..43a8b7b3 --- /dev/null +++ b/cookbooks/zero_shot_evaluation/checkpoint.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +"""Checkpoint management for evaluation pipeline.""" + +import json +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +from loguru import logger +from pydantic import BaseModel, Field + +from cookbooks.zero_shot_evaluation.schema import GeneratedQuery + + +class EvaluationStage(str, Enum): + """Evaluation pipeline stages.""" + + NOT_STARTED = "not_started" + QUERIES_GENERATED = "queries_generated" + RESPONSES_COLLECTED = "responses_collected" + RUBRICS_GENERATED = "rubrics_generated" + EVALUATION_COMPLETE = "evaluation_complete" + + +class CheckpointData(BaseModel): + """Checkpoint data model.""" + + stage: EvaluationStage = Field(default=EvaluationStage.NOT_STARTED) + created_at: str = Field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = Field(default_factory=lambda: datetime.now().isoformat()) + + # Data files + queries_file: Optional[str] = None + responses_file: Optional[str] = None + rubrics_file: Optional[str] = None + + # Progress tracking + total_queries: int = 0 + collected_responses: int = 0 + evaluated_pairs: int = 0 + total_pairs: int = 0 + + +class CheckpointManager: + """Manage evaluation checkpoints for resume capability.""" + + CHECKPOINT_FILE = "checkpoint.json" + QUERIES_FILE = "queries.json" + RESPONSES_FILE = "responses.json" + RUBRICS_FILE = "rubrics.json" + + def __init__(self, output_dir: str): + """Initialize checkpoint manager. 
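# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Intended resume flow for CheckpointManager, assuming the modules added in this
# diff are importable. Paths and the sample query are placeholders.
from cookbooks.zero_shot_evaluation.checkpoint import (
    CheckpointData,
    CheckpointManager,
    EvaluationStage,
)
from cookbooks.zero_shot_evaluation.schema import GeneratedQuery

mgr = CheckpointManager(output_dir="./evaluation_results")
ckpt = mgr.load() or CheckpointData()  # load() returns None on a fresh run

if ckpt.stage == EvaluationStage.NOT_STARTED:
    queries = [GeneratedQuery(query="How are refunds handled?", difficulty="easy")]
    queries_file = mgr.save_queries(queries)
    mgr.update_stage(
        EvaluationStage.QUERIES_GENERATED,
        queries_file=queries_file,
        total_queries=len(queries),
    )
else:
    queries = mgr.load_queries()  # resume from the queries already persisted
# --------------------------------------------------------------------------------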
+ + Args: + output_dir: Directory to store checkpoint files + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self._checkpoint: Optional[CheckpointData] = None + + @property + def checkpoint_path(self) -> Path: + return self.output_dir / self.CHECKPOINT_FILE + + def load(self) -> Optional[CheckpointData]: + """Load existing checkpoint if available.""" + if not self.checkpoint_path.exists(): + logger.info("No checkpoint found, starting fresh") + return None + + try: + with open(self.checkpoint_path, "r", encoding="utf-8") as f: + data = json.load(f) + self._checkpoint = CheckpointData(**data) + logger.info(f"Loaded checkpoint: stage={self._checkpoint.stage.value}") + return self._checkpoint + except Exception as e: + logger.warning(f"Failed to load checkpoint: {e}") + return None + + def save(self, checkpoint: CheckpointData) -> None: + """Save checkpoint to file.""" + checkpoint.updated_at = datetime.now().isoformat() + self._checkpoint = checkpoint + + with open(self.checkpoint_path, "w", encoding="utf-8") as f: + json.dump(checkpoint.model_dump(), f, indent=2, ensure_ascii=False) + + logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}") + + def save_queries(self, queries: List[GeneratedQuery]) -> str: + """Save generated queries.""" + file_path = self.output_dir / self.QUERIES_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump([q.model_dump() for q in queries], f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(queries)} queries to {file_path}") + return str(file_path) + + def load_queries(self) -> List[GeneratedQuery]: + """Load saved queries.""" + file_path = self.output_dir / self.QUERIES_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + queries = [GeneratedQuery(**item) for item in data] + logger.info(f"Loaded {len(queries)} queries from {file_path}") + return queries + + def save_responses(self, responses: List[Dict[str, Any]]) -> str: + """Save collected responses.""" + file_path = self.output_dir / self.RESPONSES_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(responses, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(responses)} responses to {file_path}") + return str(file_path) + + def load_responses(self) -> List[Dict[str, Any]]: + """Load saved responses.""" + file_path = self.output_dir / self.RESPONSES_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + responses = json.load(f) + + logger.info(f"Loaded {len(responses)} responses from {file_path}") + return responses + + def save_rubrics(self, rubrics: List[str]) -> str: + """Save generated rubrics.""" + file_path = self.output_dir / self.RUBRICS_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(rubrics, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(rubrics)} rubrics to {file_path}") + return str(file_path) + + def load_rubrics(self) -> List[str]: + """Load saved rubrics.""" + file_path = self.output_dir / self.RUBRICS_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + rubrics = json.load(f) + + logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}") + return rubrics + + def update_stage( + self, + stage: EvaluationStage, + **kwargs, + ) -> None: + """Update checkpoint stage and save.""" + if self._checkpoint is None: + self._checkpoint = CheckpointData() + + 
self._checkpoint.stage = stage + for key, value in kwargs.items(): + if hasattr(self._checkpoint, key): + setattr(self._checkpoint, key, value) + + self.save(self._checkpoint) + + def clear(self) -> None: + """Clear all checkpoint data.""" + for file_name in [ + self.CHECKPOINT_FILE, + self.QUERIES_FILE, + self.RESPONSES_FILE, + self.RUBRICS_FILE, + ]: + file_path = self.output_dir / file_name + if file_path.exists(): + file_path.unlink() + + self._checkpoint = None + logger.info("Checkpoint cleared") diff --git a/cookbooks/zero_shot_evaluation/query_generator.py b/cookbooks/zero_shot_evaluation/query_generator.py new file mode 100644 index 00000000..d9a8f78b --- /dev/null +++ b/cookbooks/zero_shot_evaluation/query_generator.py @@ -0,0 +1,699 @@ +# -*- coding: utf-8 -*- +"""Query generator for zero-shot evaluation with advanced optimization strategies. + +Features: +- Iterative generation with deduplication +- Evol-Instruct style complexity evolution +- Async parallel batch generation +""" + +import asyncio +import hashlib +from collections import defaultdict +from difflib import SequenceMatcher +from typing import Dict, List, Optional, Set, Tuple + +from loguru import logger +from tenacity import retry, stop_after_attempt, wait_exponential + +from cookbooks.zero_shot_evaluation.schema import ( + GeneratedQuery, + OpenAIEndpoint, + QueryGenerationConfig, + QueryGenerationOutput, + TaskConfig, +) +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.oai.message import ChatMessage +from openjudge.models.schema.prompt_template import PromptTemplate + +# ============================================================================= +# Prompt Templates +# ============================================================================= + +QUERY_GENERATION_PROMPT = """# Task +Based on the task description, generate diverse and representative test queries. + +## Task Description +{task_description} + +## Scenario +{scenario} + +## Seed Queries (for reference) +{seed_queries} + +## Category Distribution +{categories} + +## Already Generated Queries (AVOID similar ones) +{existing_queries} + +## Diversity Requirements +- Cover different query lengths (short: <20 words, medium: 20-50 words, long: >50 words) +- Include various difficulty levels (easy, medium, hard) +- Vary question types (factual, analytical, creative, edge cases) +- Include both common scenarios and edge cases +- AVOID semantically similar or redundant queries to existing ones + +## Anti-patterns to AVOID +- Don't generate queries too similar to existing ones listed above +- Don't generate overly generic queries like "Tell me about X" +- Don't repeat the same query structure with minor word changes +- Don't use template-like patterns repeatedly + +## Requirements +- Generate exactly {num_queries} test queries +- Each query should be independent and self-contained +- Batch ID: {batch_id} (use this to vary your generation strategy) + +## Output Format +Return a JSON object with: +- queries: list of objects, each with "query" (required), "category" (optional), "difficulty" (optional) +- reason: brief explanation of generation strategy + +Example: +{{ + "queries": [ + {{"query": "How does X handle Y in scenario Z?", "category": "technical", "difficulty": "medium"}}, + {{"query": "What happens when...", "category": "edge_case", "difficulty": "hard"}} + ], + "reason": "Generated queries covering different aspects..." 
+}} +""" + +EVOLUTION_PROMPT = """# Task +Evolve the given query into more complex versions using the specified strategy. + +## Original Query +{original_query} + +## Evolution Strategy: {strategy} + +### Strategy Descriptions: +- constraints: Add specific constraints (time, scope, conditions, limitations) +- reasoning: Require multi-step reasoning or comparison +- edge_cases: Add edge cases, exceptions, or unusual conditions +- combination: Combine with related concepts or cross-domain knowledge + +## Requirements +- Generate {num_variations} evolved versions +- Each version should be more challenging than the original +- Maintain the core intent while increasing complexity +- Evolved queries should be natural and realistic + +## Output Format +Return a JSON object with: +- evolved_queries: list of objects with "query", "difficulty", "evolution_type" +- reasoning: explanation of how complexity was increased + +Example: +{{ + "evolved_queries": [ + {{"query": "...", "difficulty": "hard", "evolution_type": "constraints"}}, + {{"query": "...", "difficulty": "hard", "evolution_type": "reasoning"}} + ], + "reasoning": "Added time constraints and multi-step reasoning..." +}} +""" + +QUERY_GENERATION_TEMPLATE = PromptTemplate( + messages=[ + ChatMessage( + role="system", + content="You are an expert at generating diverse and representative test queries for AI evaluation. " + "You excel at creating queries that cover various difficulty levels, categories, and edge cases. " + "You MUST avoid generating duplicate or semantically similar queries.", + ), + ChatMessage(role="user", content=QUERY_GENERATION_PROMPT), + ], +) + +EVOLUTION_TEMPLATE = PromptTemplate( + messages=[ + ChatMessage( + role="system", + content="You are an expert at evolving simple queries into more complex, challenging versions. 
" + "You apply the Evol-Instruct methodology to increase query complexity while maintaining naturalness.", + ), + ChatMessage(role="user", content=EVOLUTION_PROMPT), + ], +) + + +# ============================================================================= +# Evolution Output Schema +# ============================================================================= + + +from pydantic import BaseModel, Field + + +class EvolvedQuery(BaseModel): + """Single evolved query.""" + + query: str = Field(..., description="The evolved query text") + difficulty: str = Field(default="hard", description="Difficulty level") + evolution_type: str = Field(default="", description="Type of evolution applied") + + +class EvolutionOutput(BaseModel): + """Output schema for query evolution.""" + + evolved_queries: List[EvolvedQuery] = Field(..., description="List of evolved queries") + reasoning: str = Field(default="", description="Evolution reasoning") + + +# ============================================================================= +# Query Deduplicator +# ============================================================================= + + +class QueryDeduplicator: + """Handles query deduplication using multiple strategies.""" + + def __init__(self, max_similarity: float = 0.85): + self.max_similarity = max_similarity + self._seen_hashes: Set[str] = set() + self._seen_queries: List[str] = [] + + def _normalize(self, text: str) -> str: + """Normalize text for comparison.""" + return " ".join(text.lower().strip().split()) + + def _hash(self, text: str) -> str: + """Create hash for exact deduplication.""" + normalized = self._normalize(text) + return hashlib.md5(normalized.encode()).hexdigest() + + def _similarity(self, text1: str, text2: str) -> float: + """Calculate similarity ratio between two texts.""" + return SequenceMatcher(None, self._normalize(text1), self._normalize(text2)).ratio() + + def is_duplicate(self, query: str) -> bool: + """Check if query is a duplicate.""" + query_hash = self._hash(query) + + # Exact duplicate check + if query_hash in self._seen_hashes: + return True + + # Semantic similarity check (against recent queries for efficiency) + check_against = self._seen_queries[-100:] if len(self._seen_queries) > 100 else self._seen_queries + for seen in check_against: + if self._similarity(query, seen) > self.max_similarity: + return True + + return False + + def add(self, query: str) -> bool: + """Add query if not duplicate. Returns True if added.""" + if self.is_duplicate(query): + return False + + self._seen_hashes.add(self._hash(query)) + self._seen_queries.append(query) + return True + + def get_existing_summary(self, max_items: int = 10) -> str: + """Get summary of existing queries for prompt context.""" + if not self._seen_queries: + return "None yet" + + # Sample from existing queries + sample = self._seen_queries[:max_items] + return "\n".join(f"- {q[:100]}..." if len(q) > 100 else f"- {q}" for q in sample) + + +# ============================================================================= +# Query Validator +# ============================================================================= + + +class QueryValidator: + """Validate generated queries for quality.""" + + MIN_LENGTH = 5 + MAX_LENGTH = 1000 + + @classmethod + def validate(cls, query: GeneratedQuery) -> Tuple[bool, str]: + """Validate a single query. 
Returns (is_valid, reason).""" + text = query.query.strip() + + if len(text) < cls.MIN_LENGTH: + return False, f"Too short: {len(text)} chars" + + if len(text) > cls.MAX_LENGTH: + return False, f"Too long: {len(text)} chars" + + # Check for placeholder patterns + placeholders = ["[", "]", "{", "}", "...", "___", "XXX"] + for p in placeholders: + if p in text and text.count(p) > 2: + return False, f"Contains placeholder pattern: {p}" + + return True, "OK" + + +# ============================================================================= +# Query Generator +# ============================================================================= + + +class QueryGenerator: + """Generate test queries with advanced optimization strategies. + + Features: + - Iterative batch generation with deduplication + - Evol-Instruct style complexity evolution + - Async parallel generation for efficiency + - Configurable endpoint for query generation + """ + + def __init__( + self, + judge_endpoint: OpenAIEndpoint, + task_config: TaskConfig, + query_config: Optional[QueryGenerationConfig] = None, + ): + """Initialize QueryGenerator. + + Args: + judge_endpoint: OpenAI-compatible endpoint (fallback if query_config.endpoint not set) + task_config: Task configuration + query_config: Query generation configuration (including optional custom endpoint) + """ + self.task_config = task_config + self.query_config = query_config or QueryGenerationConfig() + + # Initialize deduplicator + self.deduplicator = QueryDeduplicator(max_similarity=self.query_config.max_similarity) + + # Determine which endpoint to use: custom endpoint in query_config, or fallback to judge_endpoint + if self.query_config.endpoint: + # Use custom endpoint specified in query_generation config + endpoint = self.query_config.endpoint + extra_params = endpoint.extra_params or {} + logger.info(f"Using custom query generation endpoint: {endpoint.model} @ {endpoint.base_url}") + else: + # Fallback to judge_endpoint + endpoint = judge_endpoint + extra_params = endpoint.extra_params or {} + logger.info(f"Using judge endpoint for query generation: {endpoint.model} @ {endpoint.base_url}") + + extra_params = dict(extra_params) # Make a copy to avoid modifying original + # Remove params that we'll set explicitly to avoid conflicts + extra_params.pop("stream", None) + extra_params.pop("temperature", None) + extra_params.pop("top_p", None) + + self.model = OpenAIChatModel( + model=endpoint.model, + api_key=endpoint.api_key, + base_url=endpoint.base_url, + stream=False, + temperature=self.query_config.temperature, + top_p=self.query_config.top_p, + **extra_params, + ) + + # ========================================================================= + # Main Generation Entry Point + # ========================================================================= + + async def generate(self, max_retries: int = 5) -> List[GeneratedQuery]: + """Generate test queries with all optimization strategies. + + Pipeline: + 1. Parallel batch generation (with retry until target count is reached) + 2. Deduplication + 3. Optional complexity evolution + 4. 
Final validation and filtering + + Args: + max_retries: Maximum number of retry rounds if not enough queries generated + + Returns: + List of GeneratedQuery objects + """ + target_count = self.query_config.num_queries + logger.info( + f"Starting query generation: target={target_count}, " + f"queries_per_call={self.query_config.queries_per_call}, " + f"parallel_batches={self.query_config.num_parallel_batches}, " + f"evolution={'enabled' if self.query_config.enable_evolution else 'disabled'}" + ) + + # Step 1: Parallel batch generation with retry until target count + base_queries: List[GeneratedQuery] = [] + retry_round = 0 + consecutive_failures = 0 + max_consecutive_failures = 3 # Stop after 3 consecutive complete failures + + while len(base_queries) < target_count and retry_round <= max_retries: + if retry_round > 0: + remaining = target_count - len(base_queries) + logger.info(f"Retry round {retry_round}: need {remaining} more queries") + + new_queries = await self._parallel_generate() + + # Deduplicate against existing queries + added_count = 0 + for q in new_queries: + if not self._is_duplicate(q, base_queries): + base_queries.append(q) + added_count += 1 + + logger.info(f"After round {retry_round}: {len(base_queries)} queries collected (+{added_count} new)") + + if len(new_queries) == 0: + consecutive_failures += 1 + logger.warning( + f"Round {retry_round} produced 0 queries (consecutive failures: {consecutive_failures}/{max_consecutive_failures})" + ) + if consecutive_failures >= max_consecutive_failures: + logger.error(f"Stopping after {max_consecutive_failures} consecutive complete failures") + break + else: + consecutive_failures = 0 # Reset on any success + + retry_round += 1 + + logger.info(f"Base generation complete: {len(base_queries)} queries") + + # Step 2: Optional complexity evolution + if self.query_config.enable_evolution and self.query_config.evolution_rounds > 0: + evolved_queries = await self._evolve_queries(base_queries) + base_queries.extend(evolved_queries) + logger.info(f"After evolution: {len(base_queries)} queries") + + # Step 3: Final deduplication and validation + final_queries = self._final_filter(base_queries) + logger.info(f"After final filtering: {len(final_queries)} queries") + + # Step 4: Trim to target count + result = final_queries[:target_count] + logger.info(f"Final result: {len(result)} queries (target: {target_count})") + + return result + + def _is_duplicate(self, query: GeneratedQuery, existing: List[GeneratedQuery]) -> bool: + """Check if a query is duplicate of existing queries (simple text comparison).""" + query_text = query.query.strip().lower() + for eq in existing: + if query_text == eq.query.strip().lower(): + return True + return False + + # ========================================================================= + # Parallel Batch Generation + # ========================================================================= + + async def _parallel_generate(self) -> List[GeneratedQuery]: + """Generate queries in parallel batches for better diversity.""" + num_batches = self.query_config.num_parallel_batches + queries_per_call = self.query_config.queries_per_call + + # Calculate target per batch, respecting queries_per_call limit + # Generate extra to account for deduplication + ideal_per_batch = (self.query_config.num_queries * 2) // num_batches + 1 + target_per_batch = min(ideal_per_batch, queries_per_call) + + logger.info( + f"Launching {num_batches} parallel generation batches, " + f"{target_per_batch} queries each 
(queries_per_call={queries_per_call})" + ) + + # Create tasks for parallel execution + tasks = [self._generate_batch(batch_id=i, num_queries=target_per_batch) for i in range(num_batches)] + + # Execute in parallel + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Collect results, handling any errors + all_queries: List[GeneratedQuery] = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.warning(f"Batch {i} failed: {result}") + else: + all_queries.extend(result) + + return all_queries + + @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10)) + async def _generate_batch(self, batch_id: int, num_queries: int) -> List[GeneratedQuery]: + """Generate a single batch of queries. + + Args: + batch_id: Batch identifier for diversity + num_queries: Number of queries to generate + + Returns: + List of GeneratedQuery objects + """ + # Format seed queries + seed_queries_text = "None provided" + if self.query_config.seed_queries: + seed_queries_text = "\n".join(f"- {q}" for q in self.query_config.seed_queries) + + # Format categories + categories_text = "No specific categories, generate diverse queries" + if self.query_config.categories: + categories_text = "\n".join( + f"- {cat.get('name', 'unknown')}: weight {cat.get('weight', 1.0)}" + for cat in self.query_config.categories + ) + + # Get existing queries context + existing_context = self.deduplicator.get_existing_summary(max_items=15) + + # Build prompt + messages = QUERY_GENERATION_TEMPLATE.format( + task_description=self.task_config.description, + scenario=self.task_config.scenario or "General usage", + seed_queries=seed_queries_text, + categories=categories_text, + existing_queries=existing_context, + num_queries=num_queries, + batch_id=batch_id, + ) + + # Call model with structured output + response = await self.model.achat( + messages=list(messages), + structured_model=QueryGenerationOutput, + ) + + if not response.parsed or "queries" not in response.parsed: + raise ValueError(f"Failed to parse query generation response for batch {batch_id}") + + # Parse and deduplicate queries + queries: List[GeneratedQuery] = [] + for q in response.parsed["queries"]: + if isinstance(q, dict): + query_obj = GeneratedQuery(**q) + else: + query_obj = q + + # Validate and deduplicate + is_valid, reason = QueryValidator.validate(query_obj) + if not is_valid: + logger.debug(f"Batch {batch_id}: Skipping invalid query: {reason}") + continue + + if self.deduplicator.add(query_obj.query): + queries.append(query_obj) + else: + logger.debug(f"Batch {batch_id}: Skipping duplicate query") + + logger.info(f"Batch {batch_id}: Generated {len(queries)} valid unique queries") + return queries + + # ========================================================================= + # Evol-Instruct Complexity Evolution + # ========================================================================= + + async def _evolve_queries(self, base_queries: List[GeneratedQuery]) -> List[GeneratedQuery]: + """Apply Evol-Instruct style complexity evolution to queries. 
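# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Batch sizing used by _parallel_generate above, with the schema defaults
# (num_queries=20, num_parallel_batches=3, queries_per_call=10).
num_queries, num_batches, queries_per_call = 20, 3, 10

ideal_per_batch = (num_queries * 2) // num_batches + 1   # over-generate ~2x to survive dedup
target_per_batch = min(ideal_per_batch, queries_per_call)

print(ideal_per_batch, target_per_batch)  # 14, 10 -> 3 batches x 10 = 30 candidates for a target of 20
# --------------------------------------------------------------------------------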
+ + Args: + base_queries: Base queries to evolve + + Returns: + List of evolved queries + """ + if not base_queries: + return [] + + # Select queries for evolution (prefer easier ones and seed queries) + candidates = self._select_evolution_candidates(base_queries) + logger.info(f"Selected {len(candidates)} queries for evolution") + + evolved_queries: List[GeneratedQuery] = [] + + for round_idx in range(self.query_config.evolution_rounds): + logger.info(f"Evolution round {round_idx + 1}/{self.query_config.evolution_rounds}") + + # Create evolution tasks for each strategy + tasks = [] + for query in candidates: + for strategy in self.query_config.complexity_levels: + tasks.append(self._evolve_single(query, strategy)) + + # Execute evolutions in parallel (with some concurrency limit) + semaphore = asyncio.Semaphore(5) + + async def limited_evolve(task): + async with semaphore: + return await task + + results = await asyncio.gather(*[limited_evolve(t) for t in tasks], return_exceptions=True) + + # Collect results + for result in results: + if isinstance(result, Exception): + logger.debug(f"Evolution failed: {result}") + elif result: + evolved_queries.extend(result) + + # Update candidates for next round + if evolved_queries: + candidates = self._select_evolution_candidates(evolved_queries[-10:]) + + return evolved_queries + + def _select_evolution_candidates( + self, queries: List[GeneratedQuery], max_candidates: int = 5 + ) -> List[GeneratedQuery]: + """Select best candidates for evolution.""" + # Prefer easier queries and shorter ones for evolution + scored = [] + for q in queries: + score = 0 + if q.difficulty == "easy": + score += 2 + elif q.difficulty == "medium": + score += 1 + # Prefer medium-length queries + length = len(q.query) + if 20 <= length <= 100: + score += 1 + scored.append((score, q)) + + scored.sort(key=lambda x: -x[0]) + return [q for _, q in scored[:max_candidates]] + + @retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=0.5, max=5)) + async def _evolve_single(self, query: GeneratedQuery, strategy: str) -> List[GeneratedQuery]: + """Evolve a single query using the specified strategy. 
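# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Scoring applied by _select_evolution_candidates above: easier, medium-length
# queries are preferred as evolution seeds. Example queries are placeholders.
def score(difficulty: str, query_text: str) -> int:
    s = {"easy": 2, "medium": 1}.get(difficulty, 0)
    if 20 <= len(query_text) <= 100:
        s += 1
    return s


print(score("easy", "Explain retries."))                                        # 2 (easy, but too short for the length bonus)
print(score("medium", "Compare the retry and timeout settings of both APIs."))  # 2 (medium + within the 20-100 char range)
print(score("hard", "x" * 150))                                                 # 0 (hard and too long)
# --------------------------------------------------------------------------------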
+ + Args: + query: Query to evolve + strategy: Evolution strategy (constraints, reasoning, edge_cases, combination) + + Returns: + List of evolved queries + """ + messages = EVOLUTION_TEMPLATE.format( + original_query=query.query, + strategy=strategy, + num_variations=2, + ) + + try: + response = await self.model.achat( + messages=list(messages), + structured_model=EvolutionOutput, + ) + + if not response.parsed or "evolved_queries" not in response.parsed: + return [] + + evolved: List[GeneratedQuery] = [] + for eq in response.parsed["evolved_queries"]: + if isinstance(eq, dict): + evolved_query = GeneratedQuery( + query=eq.get("query", ""), + category=query.category, # Inherit category + difficulty=eq.get("difficulty", "hard"), + ) + else: + evolved_query = GeneratedQuery( + query=eq.query, + category=query.category, + difficulty=eq.difficulty, + ) + + # Validate and deduplicate + is_valid, _ = QueryValidator.validate(evolved_query) + if is_valid and self.deduplicator.add(evolved_query.query): + evolved.append(evolved_query) + + return evolved + + except Exception as e: + logger.debug(f"Evolution failed for strategy {strategy}: {e}") + return [] + + # ========================================================================= + # Final Filtering and Balancing + # ========================================================================= + + def _final_filter(self, queries: List[GeneratedQuery]) -> List[GeneratedQuery]: + """Apply final filtering and category balancing. + + Args: + queries: All generated queries + + Returns: + Filtered and balanced queries + """ + # Re-validate all queries + valid_queries = [] + for q in queries: + is_valid, _ = QueryValidator.validate(q) + if is_valid: + valid_queries.append(q) + + # Apply category balancing if categories specified + if self.query_config.categories: + return self._balance_categories(valid_queries) + + return valid_queries + + def _balance_categories(self, queries: List[GeneratedQuery]) -> List[GeneratedQuery]: + """Balance queries according to category weights.""" + if not self.query_config.categories: + return queries + + # Calculate target counts per category + total_weight = sum(c.get("weight", 1.0) for c in self.query_config.categories) + target_counts: Dict[str, int] = {} + for cat in self.query_config.categories: + name = cat.get("name", "general") + weight = cat.get("weight", 1.0) + target_counts[name] = max(1, int(self.query_config.num_queries * weight / total_weight)) + + # Group queries by category + by_category: Dict[str, List[GeneratedQuery]] = defaultdict(list) + uncategorized: List[GeneratedQuery] = [] + + for q in queries: + if q.category and q.category in target_counts: + by_category[q.category].append(q) + else: + uncategorized.append(q) + + # Build balanced result + result: List[GeneratedQuery] = [] + + for cat, target in target_counts.items(): + available = by_category.get(cat, []) + result.extend(available[:target]) + + # Fill remaining with uncategorized + remaining = self.query_config.num_queries - len(result) + result.extend(uncategorized[:remaining]) + + return result diff --git a/cookbooks/zero_shot_evaluation/response_collector.py b/cookbooks/zero_shot_evaluation/response_collector.py new file mode 100644 index 00000000..954329aa --- /dev/null +++ b/cookbooks/zero_shot_evaluation/response_collector.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +"""Response collector for zero-shot evaluation.""" + +import asyncio +from typing import Any, Dict, List, Optional + +from loguru import logger +from tenacity import retry, 
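# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# How _balance_categories above turns category weights into per-category targets
# for the default num_queries=20. Weights need not sum to 1; they are normalized.
num_queries = 20
categories = [
    {"name": "technical", "weight": 2.0},
    {"name": "edge_case", "weight": 1.0},
    {"name": "general", "weight": 1.0},
]

total_weight = sum(c["weight"] for c in categories)
target_counts = {
    c["name"]: max(1, int(num_queries * c["weight"] / total_weight)) for c in categories
}
print(target_counts)  # {'technical': 10, 'edge_case': 5, 'general': 5}
# --------------------------------------------------------------------------------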
stop_after_attempt, wait_exponential + +from cookbooks.zero_shot_evaluation.schema import ( + EvaluationConfig, + GeneratedQuery, + OpenAIEndpoint, +) +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.oai.message import ChatMessage +from openjudge.utils.concurrency import ConcurrencyManager + + +class ResponseCollector: + """Collect responses from multiple target endpoints.""" + + def __init__( + self, + target_endpoints: Dict[str, OpenAIEndpoint], + evaluation_config: Optional[EvaluationConfig] = None, + ): + """Initialize ResponseCollector. + + Args: + target_endpoints: Dictionary of endpoint name to configuration + evaluation_config: Evaluation configuration + """ + self.endpoints = target_endpoints + self.config = evaluation_config or EvaluationConfig() + + # Initialize models for each endpoint (force stream=False) + self.models: Dict[str, OpenAIChatModel] = {} + self.system_prompts: Dict[str, Optional[str]] = {} + + for name, endpoint in target_endpoints.items(): + extra_params = endpoint.extra_params or {} + # Ensure stream is disabled to avoid async generator issues + extra_params.pop("stream", None) + self.models[name] = OpenAIChatModel( + model=endpoint.model, + api_key=endpoint.api_key, + base_url=endpoint.base_url, + stream=False, + **extra_params, + ) + self.system_prompts[name] = endpoint.system_prompt + + # Setup concurrency manager (use singleton's set method) + self.concurrency_manager = ConcurrencyManager() + self.concurrency_manager.set_max_concurrency(self.config.max_concurrency) + + async def _call_endpoint( + self, + endpoint_name: str, + query: str, + ) -> Dict[str, Any]: + """Call a single endpoint with a query (with retry). + + Args: + endpoint_name: Name of the endpoint + query: Query text + + Returns: + Dictionary with response or error + """ + model = self.models[endpoint_name] + system_prompt = self.system_prompts[endpoint_name] + + messages: List[ChatMessage] = [] + if system_prompt: + messages.append(ChatMessage(role="system", content=system_prompt)) + messages.append(ChatMessage(role="user", content=query)) + + # Create retry decorator with configured retry_times + @retry( + stop=stop_after_attempt(self.config.retry_times), + wait=wait_exponential(multiplier=1, min=1, max=10), + reraise=True, + ) + async def _call_with_retry(): + return await asyncio.wait_for( + model.achat(messages=messages), + timeout=self.config.timeout, + ) + + try: + response = await _call_with_retry() + return { + "endpoint": endpoint_name, + "response": response.content, + "success": True, + } + except asyncio.TimeoutError: + logger.warning(f"Timeout calling {endpoint_name} for query: {query[:50]}...") + return { + "endpoint": endpoint_name, + "response": None, + "success": False, + "error": "timeout", + } + except Exception as e: + logger.warning(f"Error calling {endpoint_name} after {self.config.retry_times} retries: {e}") + return { + "endpoint": endpoint_name, + "response": None, + "success": False, + "error": str(e), + } + + async def collect_single(self, query: str) -> Dict[str, Any]: + """Collect responses from all endpoints for a single query. 
+ + Args: + query: Query text + + Returns: + Dictionary mapping endpoint names to responses + """ + tasks = [ + self.concurrency_manager.run_with_concurrency_control( + self._call_endpoint(name, query), + ) + for name in self.endpoints + ] + + results = await asyncio.gather(*tasks) + + responses = {} + for result in results: + endpoint_name = result["endpoint"] + if result["success"]: + responses[endpoint_name] = result["response"] + else: + responses[endpoint_name] = None + logger.debug(f"Failed response from {endpoint_name}: {result.get('error')}") + + return responses + + async def collect( + self, + queries: List[GeneratedQuery], + ) -> List[Dict[str, Any]]: + """Collect responses from all endpoints for all queries (fully parallel). + + Args: + queries: List of GeneratedQuery objects + + Returns: + List of dictionaries, each containing query and responses + """ + total_calls = len(queries) * len(self.endpoints) + logger.info( + f"Collecting responses for {len(queries)} queries from {len(self.endpoints)} endpoints " + f"({total_calls} total calls, max_concurrency={self.config.max_concurrency})" + ) + + # Create all (query_idx, endpoint_name) tasks + async def _collect_one(query_idx: int, endpoint_name: str) -> Dict[str, Any]: + query_obj = queries[query_idx] + result = await self._call_endpoint(endpoint_name, query_obj.query) + return { + "query_idx": query_idx, + "endpoint": endpoint_name, + "result": result, + } + + # Launch all tasks with concurrency control + tasks = [ + self.concurrency_manager.run_with_concurrency_control(_collect_one(i, ep_name)) + for i in range(len(queries)) + for ep_name in self.endpoints + ] + + # Progress tracking + completed = 0 + all_results = [] + + for coro in asyncio.as_completed(tasks): + result = await coro + all_results.append(result) + completed += 1 + if completed % 10 == 0 or completed == total_calls: + logger.info(f"Progress: {completed}/{total_calls} calls completed") + + # Organize results by query + results_by_query: Dict[int, Dict[str, Any]] = {} + for item in all_results: + query_idx = item["query_idx"] + endpoint = item["endpoint"] + result = item["result"] + + if query_idx not in results_by_query: + query_obj = queries[query_idx] + results_by_query[query_idx] = { + "query": query_obj.query, + "category": query_obj.category, + "difficulty": query_obj.difficulty, + "responses": {}, + } + + if result["success"]: + results_by_query[query_idx]["responses"][endpoint] = result["response"] + else: + results_by_query[query_idx]["responses"][endpoint] = None + logger.debug(f"Failed response from {endpoint}: {result.get('error')}") + + # Convert to ordered list + results = [results_by_query[i] for i in range(len(queries))] + + # Log summary + success_count = sum(1 for r in results if all(v is not None for v in r["responses"].values())) + logger.info(f"Collected responses: {success_count}/{len(results)} queries fully successful") + + return results diff --git a/cookbooks/zero_shot_evaluation/schema.py b/cookbooks/zero_shot_evaluation/schema.py new file mode 100644 index 00000000..62bd9370 --- /dev/null +++ b/cookbooks/zero_shot_evaluation/schema.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +"""Data schemas and configuration loading for zero-shot evaluation. + +This module provides: +- Data models for configuration (OpenAIEndpoint, ZeroShotConfig, etc.) 
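# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Shape of one item returned by ResponseCollector.collect() above; the endpoint
# names ("model_a", "model_b") stand in for the keys of target_endpoints.
example_item = {
    "query": "How does the service behave when the cache is cold?",
    "category": "technical",
    "difficulty": "medium",
    "responses": {
        "model_a": "When the cache is cold, the service ...",
        "model_b": None,  # None marks a failed or timed-out call for that endpoint
    },
}
# --------------------------------------------------------------------------------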
+- Configuration loading utilities (load_config, resolve_env_vars) +""" + +import os +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml +from loguru import logger +from pydantic import BaseModel, Field + +# ============================================================================= +# Data Models +# ============================================================================= + + +class OpenAIEndpoint(BaseModel): + """OpenAI-compatible endpoint configuration. + + This schema is used for all endpoint configurations including: + - Target model endpoints + - Judge model endpoint + - Query generation endpoint (optional) + """ + + base_url: str = Field(..., description="API base URL") + api_key: str = Field(..., description="API key, supports ${ENV_VAR} format") + model: str = Field(..., description="Model name") + system_prompt: Optional[str] = Field(default=None, description="System prompt") + extra_params: Optional[Dict[str, Any]] = Field(default=None, description="Extra request parameters") + + +class TaskConfig(BaseModel): + """Task configuration.""" + + description: str = Field(..., description="Task description") + scenario: Optional[str] = Field(default=None, description="Usage scenario") + + +class QueryGenerationConfig(BaseModel): + """Query generation configuration.""" + + num_queries: int = Field(default=20, description="Number of queries to generate") + seed_queries: Optional[List[str]] = Field(default=None, description="Seed queries for generation") + categories: Optional[List[Dict[str, Any]]] = Field(default=None, description="Query categories") + + # Endpoint configuration (optional, defaults to judge_endpoint if not specified) + # Uses OpenAIEndpoint for consistency + endpoint: Optional[OpenAIEndpoint] = Field( + default=None, + description="Custom endpoint for query generation. 
If not set, uses judge_endpoint.", + ) + + # Diversity control parameters + temperature: float = Field(default=0.9, ge=0.0, le=2.0, description="Sampling temperature for diversity") + top_p: float = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling") + + # Batch generation parameters + queries_per_call: int = Field(default=10, ge=1, le=50, description="Number of queries to generate per API call") + num_parallel_batches: int = Field(default=3, ge=1, description="Number of parallel batches") + max_similarity: float = Field(default=0.85, ge=0.0, le=1.0, description="Max similarity threshold for dedup") + + # Evol-Instruct parameters + enable_evolution: bool = Field(default=False, description="Enable complexity evolution") + evolution_rounds: int = Field(default=1, ge=0, le=3, description="Number of evolution rounds") + complexity_levels: List[str] = Field( + default=["constraints", "reasoning", "edge_cases"], + description="Complexity evolution strategies", + ) + + +class EvaluationConfig(BaseModel): + """Evaluation configuration.""" + + max_concurrency: int = Field(default=10, description="Maximum concurrency") + timeout: int = Field(default=60, description="Request timeout in seconds") + retry_times: int = Field(default=3, description="Number of retries") + + +class OutputConfig(BaseModel): + """Output configuration.""" + + save_queries: bool = Field(default=True, description="Save generated queries") + save_responses: bool = Field(default=True, description="Save all responses") + save_details: bool = Field(default=True, description="Save detailed results") + output_dir: str = Field(default="./evaluation_results", description="Output directory") + + +class ZeroShotConfig(BaseModel): + """Complete zero-shot evaluation configuration.""" + + task: TaskConfig + target_endpoints: Dict[str, OpenAIEndpoint] + judge_endpoint: OpenAIEndpoint + query_generation: QueryGenerationConfig = Field(default_factory=QueryGenerationConfig) + evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) + output: OutputConfig = Field(default_factory=OutputConfig) + + +class GeneratedQuery(BaseModel): + """Generated query item.""" + + query: str = Field(..., description="The query text") + category: Optional[str] = Field(default=None, description="Query category") + difficulty: Optional[str] = Field(default=None, description="Query difficulty") + + +class QueryGenerationOutput(BaseModel): + """Output schema for query generation.""" + + queries: List[GeneratedQuery] = Field(..., description="List of generated queries") + reason: str = Field(default="", description="Generation reasoning") + + +# ============================================================================= +# Configuration Loading +# ============================================================================= + + +def resolve_env_vars(value: Any) -> Any: + """Resolve environment variables in configuration values. + + Supports ${VAR_NAME} format. 
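# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Programmatic equivalent of the YAML config consumed by load_config(), assuming
# the schema module from this diff is importable. URLs and model names are
# placeholders; in YAML the api_key would normally use ${ENV_VAR} resolution.
import os

from cookbooks.zero_shot_evaluation.schema import (
    OpenAIEndpoint,
    QueryGenerationConfig,
    TaskConfig,
    ZeroShotConfig,
)

config = ZeroShotConfig(
    task=TaskConfig(description="Customer-support assistant for a billing product"),
    target_endpoints={
        "model_a": OpenAIEndpoint(
            base_url="https://api.example.com/v1",
            api_key=os.getenv("MODEL_A_API_KEY", ""),
            model="model-a",
        ),
        "model_b": OpenAIEndpoint(
            base_url="https://api.example.com/v1",
            api_key=os.getenv("MODEL_B_API_KEY", ""),
            model="model-b",
        ),
    },
    judge_endpoint=OpenAIEndpoint(
        base_url="https://api.example.com/v1",
        api_key=os.getenv("JUDGE_API_KEY", ""),
        model="judge-model",
    ),
    query_generation=QueryGenerationConfig(num_queries=10),
)
# --------------------------------------------------------------------------------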
+ + Args: + value: Configuration value (can be str, dict, or list) + + Returns: + Value with environment variables resolved + """ + if isinstance(value, str): + pattern = r"\$\{(\w+)\}" + matches = re.findall(pattern, value) + for var_name in matches: + env_value = os.getenv(var_name, "") + if not env_value: + logger.warning(f"Environment variable {var_name} not set") + value = value.replace(f"${{{var_name}}}", env_value) + return value + elif isinstance(value, dict): + return {k: resolve_env_vars(v) for k, v in value.items()} + elif isinstance(value, list): + return [resolve_env_vars(item) for item in value] + return value + + +def load_config(config_path: Union[str, Path]) -> ZeroShotConfig: + """Load and validate configuration from YAML file. + + Args: + config_path: Path to the configuration file + + Returns: + Validated ZeroShotConfig object + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If config validation fails + """ + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + with open(config_path, "r", encoding="utf-8") as f: + raw_config = yaml.safe_load(f) + + # Resolve environment variables + resolved_config = resolve_env_vars(raw_config) + + # Validate and create config object + config = ZeroShotConfig(**resolved_config) + logger.info(f"Loaded configuration from {config_path}") + logger.info(f"Task: {config.task.description}") + logger.info(f"Target endpoints: {list(config.target_endpoints.keys())}") + + return config + + +def config_to_dict(config: ZeroShotConfig) -> Dict[str, Any]: + """Convert ZeroShotConfig to dictionary (for serialization). + + Args: + config: ZeroShotConfig object + + Returns: + Dictionary representation + """ + return config.model_dump() diff --git a/cookbooks/zero_shot_evaluation/zero_shot_pipeline.py b/cookbooks/zero_shot_evaluation/zero_shot_pipeline.py new file mode 100644 index 00000000..03f5e3d6 --- /dev/null +++ b/cookbooks/zero_shot_evaluation/zero_shot_pipeline.py @@ -0,0 +1,736 @@ +# -*- coding: utf-8 -*- +"""End-to-end pipeline for zero-shot evaluation. + +This module provides the ZeroShotPipeline class for end-to-end evaluation +of AI models without labeled data. It integrates with OpenJudge's core +components for grading, analysis, and rubric generation. + +Pipeline Steps: + 1. Generate test queries + 2. Collect responses from target endpoints + 3. Generate evaluation rubrics + 4. Run pairwise evaluation + 5. 
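# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# Behavior of resolve_env_vars() defined above: ${VAR} placeholders are replaced
# recursively in strings, dicts, and lists; unset variables become "" with a warning.
import os

from cookbooks.zero_shot_evaluation.schema import resolve_env_vars

os.environ["JUDGE_API_KEY"] = "sk-test"
raw = {"judge_endpoint": {"api_key": "${JUDGE_API_KEY}", "model": "judge-model"}}
print(resolve_env_vars(raw))
# {'judge_endpoint': {'api_key': 'sk-test', 'model': 'judge-model'}}
# --------------------------------------------------------------------------------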
Analyze and rank results +""" + +import json +from datetime import datetime +from enum import Enum +from itertools import combinations +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from loguru import logger +from pydantic import BaseModel, Field + +from cookbooks.zero_shot_evaluation.query_generator import QueryGenerator +from cookbooks.zero_shot_evaluation.response_collector import ResponseCollector +from cookbooks.zero_shot_evaluation.schema import ( + GeneratedQuery, + OpenAIEndpoint, + ZeroShotConfig, + load_config, +) + +# OpenJudge core components +from openjudge.analyzer import PairwiseAnalysisResult, PairwiseAnalyzer +from openjudge.generator.simple_rubric import TaskBasedRubricGenerator +from openjudge.graders.llm_grader import GraderMode, LLMGrader +from openjudge.graders.schema import GraderResult +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.oai.message import ChatMessage +from openjudge.models.schema.prompt_template import PromptTemplate +from openjudge.runner.grading_runner import GraderConfig, GradingRunner + +# ============================================================================= +# Checkpoint Management (integrated from checkpoint.py) +# ============================================================================= + + +class EvaluationStage(str, Enum): + """Evaluation pipeline stages.""" + + NOT_STARTED = "not_started" + QUERIES_GENERATED = "queries_generated" + RESPONSES_COLLECTED = "responses_collected" + RUBRICS_GENERATED = "rubrics_generated" + EVALUATION_COMPLETE = "evaluation_complete" + + +class _CheckpointData(BaseModel): + """Internal checkpoint data model.""" + + stage: EvaluationStage = Field(default=EvaluationStage.NOT_STARTED) + created_at: str = Field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = Field(default_factory=lambda: datetime.now().isoformat()) + + # Data files + queries_file: Optional[str] = None + responses_file: Optional[str] = None + rubrics_file: Optional[str] = None + + # Progress tracking + total_queries: int = 0 + collected_responses: int = 0 + evaluated_pairs: int = 0 + total_pairs: int = 0 + + +class _CheckpointManager: + """Internal checkpoint manager for evaluation pipeline resume capability.""" + + CHECKPOINT_FILE = "checkpoint.json" + QUERIES_FILE = "queries.json" + RESPONSES_FILE = "responses.json" + RUBRICS_FILE = "rubrics.json" + + def __init__(self, output_dir: str): + """Initialize checkpoint manager. 
+ + Args: + output_dir: Directory to store checkpoint files + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self._checkpoint: Optional[_CheckpointData] = None + + @property + def checkpoint_path(self) -> Path: + return self.output_dir / self.CHECKPOINT_FILE + + def load(self) -> Optional[_CheckpointData]: + """Load existing checkpoint if available.""" + if not self.checkpoint_path.exists(): + logger.info("No checkpoint found, starting fresh") + return None + + try: + with open(self.checkpoint_path, "r", encoding="utf-8") as f: + data = json.load(f) + self._checkpoint = _CheckpointData(**data) + logger.info(f"Loaded checkpoint: stage={self._checkpoint.stage.value}") + return self._checkpoint + except Exception as e: + logger.warning(f"Failed to load checkpoint: {e}") + return None + + def save(self, checkpoint: _CheckpointData) -> None: + """Save checkpoint to file.""" + checkpoint.updated_at = datetime.now().isoformat() + self._checkpoint = checkpoint + + with open(self.checkpoint_path, "w", encoding="utf-8") as f: + json.dump(checkpoint.model_dump(), f, indent=2, ensure_ascii=False) + + logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}") + + def save_queries(self, queries: List[GeneratedQuery]) -> str: + """Save generated queries.""" + file_path = self.output_dir / self.QUERIES_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump([q.model_dump() for q in queries], f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(queries)} queries to {file_path}") + return str(file_path) + + def load_queries(self) -> List[GeneratedQuery]: + """Load saved queries.""" + file_path = self.output_dir / self.QUERIES_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + queries = [GeneratedQuery(**item) for item in data] + logger.info(f"Loaded {len(queries)} queries from {file_path}") + return queries + + def save_responses(self, responses: List[Dict[str, Any]]) -> str: + """Save collected responses.""" + file_path = self.output_dir / self.RESPONSES_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(responses, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(responses)} responses to {file_path}") + return str(file_path) + + def load_responses(self) -> List[Dict[str, Any]]: + """Load saved responses.""" + file_path = self.output_dir / self.RESPONSES_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + responses = json.load(f) + + logger.info(f"Loaded {len(responses)} responses from {file_path}") + return responses + + def save_rubrics(self, rubrics: List[str]) -> str: + """Save generated rubrics.""" + file_path = self.output_dir / self.RUBRICS_FILE + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(rubrics, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(rubrics)} rubrics to {file_path}") + return str(file_path) + + def load_rubrics(self) -> List[str]: + """Load saved rubrics.""" + file_path = self.output_dir / self.RUBRICS_FILE + + if not file_path.exists(): + return [] + + with open(file_path, "r", encoding="utf-8") as f: + rubrics = json.load(f) + + logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}") + return rubrics + + def update_stage( + self, + stage: EvaluationStage, + **kwargs, + ) -> None: + """Update checkpoint stage and save.""" + if self._checkpoint is None: + self._checkpoint = _CheckpointData() + + 
self._checkpoint.stage = stage + for key, value in kwargs.items(): + if hasattr(self._checkpoint, key): + setattr(self._checkpoint, key, value) + + self.save(self._checkpoint) + + def clear(self) -> None: + """Clear all checkpoint data.""" + for file_name in [ + self.CHECKPOINT_FILE, + self.QUERIES_FILE, + self.RESPONSES_FILE, + self.RUBRICS_FILE, + ]: + file_path = self.output_dir / file_name + if file_path.exists(): + file_path.unlink() + + self._checkpoint = None + logger.info("Checkpoint cleared") + + +# ============================================================================= +# Evaluation Result +# ============================================================================= + + +class EvaluationResult(BaseModel): + """Result of zero-shot evaluation. + + Attributes: + rankings: List of (model_name, win_rate) tuples sorted by win rate + win_rates: Win rate for each model + win_matrix: Win rate matrix where win_matrix[A][B] = how often A beats B + best_pipeline: Name of the best performing pipeline + total_queries: Total number of queries evaluated + total_comparisons: Total number of pairwise comparisons + """ + + rankings: List[Tuple[str, float]] = Field(default_factory=list) + win_rates: Dict[str, float] = Field(default_factory=dict) + win_matrix: Dict[str, Dict[str, float]] = Field(default_factory=dict) + best_pipeline: str = Field(default="") + total_queries: int = Field(default=0) + total_comparisons: int = Field(default=0) + + @classmethod + def from_analysis(cls, analysis: PairwiseAnalysisResult, total_queries: int) -> "EvaluationResult": + """Create EvaluationResult from PairwiseAnalysisResult. + + Args: + analysis: Analysis result from PairwiseAnalyzer + total_queries: Total number of queries evaluated + + Returns: + EvaluationResult instance + """ + return cls( + rankings=analysis.rankings, + win_rates=analysis.win_rates, + win_matrix=analysis.win_matrix, + best_pipeline=analysis.best_model, + total_queries=total_queries, + total_comparisons=analysis.total_comparisons, + ) + + +# ============================================================================= +# Zero-Shot Pipeline +# ============================================================================= + + +class ZeroShotPipeline: + """End-to-end zero-shot evaluation pipeline with checkpoint support. + + This pipeline automates the complete evaluation process: + 1. Generate diverse test queries based on task description + 2. Collect responses from multiple target endpoints + 3. Generate evaluation rubrics using LLM + 4. Run pairwise comparisons between model responses + 5. 
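# --- Illustrative sketch (reviewer annotation, not part of this diff) -----------
# How a finished EvaluationResult is typically read; the numbers below are made up.
# win_matrix[a][b] is the fraction of comparisons in which endpoint a beat endpoint b.
from cookbooks.zero_shot_evaluation.zero_shot_pipeline import EvaluationResult

result = EvaluationResult(
    rankings=[("model_a", 0.62), ("model_b", 0.38)],
    win_rates={"model_a": 0.62, "model_b": 0.38},
    win_matrix={"model_a": {"model_b": 0.62}, "model_b": {"model_a": 0.38}},
    best_pipeline="model_a",
    total_queries=20,
    total_comparisons=40,
)

for name, rate in result.rankings:
    print(f"{name}: win rate {rate:.0%}")
# --------------------------------------------------------------------------------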
Analyze results and rank models + + The pipeline integrates with OpenJudge's core components: + - Uses TaskBasedRubricGenerator from openjudge.generator.simple_rubric for rubric generation + - Uses PairwiseAnalyzer from openjudge.analyzer for result analysis + - Uses LLMGrader and GradingRunner for pairwise evaluation + + Attributes: + config: Pipeline configuration + _queries: Generated queries + _responses: Collected responses + _rubrics: Generated rubrics + + Example: + >>> from cookbooks.zero_shot_evaluation import ZeroShotPipeline + >>> pipeline = ZeroShotPipeline.from_config("config.yaml") + >>> result = await pipeline.evaluate() + >>> print(f"Best model: {result.best_pipeline}") + """ + + def __init__( + self, + config: Optional[ZeroShotConfig] = None, + *, + task_description: Optional[str] = None, + target_endpoints: Optional[Dict[str, OpenAIEndpoint]] = None, + judge_endpoint: Optional[OpenAIEndpoint] = None, + num_queries: int = 20, + resume: bool = True, + ): + """Initialize ZeroShotPipeline. + + Args: + config: Complete configuration object + task_description: Task description (alternative to config) + target_endpoints: Target endpoints (alternative to config) + judge_endpoint: Judge endpoint (alternative to config) + num_queries: Number of queries to generate + resume: Whether to resume from checkpoint if available + """ + if config: + self.config = config + else: + if not all([task_description, target_endpoints, judge_endpoint]): + raise ValueError("Must provide either config or all individual parameters") + from cookbooks.zero_shot_evaluation.schema import ( + EvaluationConfig, + QueryGenerationConfig, + TaskConfig, + ) + + self.config = ZeroShotConfig( + task=TaskConfig(description=task_description), + target_endpoints=target_endpoints, + judge_endpoint=judge_endpoint, + query_generation=QueryGenerationConfig(num_queries=num_queries), + evaluation=EvaluationConfig(), + ) + + self._queries: List[GeneratedQuery] = [] + self._responses: List[Dict[str, Any]] = [] + self._rubrics: List[str] = [] + + # Initialize checkpoint manager + self._checkpoint_mgr = _CheckpointManager(self.config.output.output_dir) + self._resume = resume + + @classmethod + def from_config(cls, config_path: Union[str, Path]) -> "ZeroShotPipeline": + """Create pipeline from configuration file. + + Args: + config_path: Path to YAML configuration file + + Returns: + ZeroShotPipeline instance + """ + config = load_config(config_path) + return cls(config=config) + + def _create_judge_model(self) -> OpenAIChatModel: + """Create judge model from endpoint configuration.""" + endpoint = self.config.judge_endpoint + extra_params = endpoint.extra_params or {} + return OpenAIChatModel( + model=endpoint.model, + api_key=endpoint.api_key, + base_url=endpoint.base_url, + **extra_params, + ) + + async def generate_queries(self) -> List[GeneratedQuery]: + """Step 1: Generate test queries.""" + logger.info("Step 1: Generating test queries...") + generator = QueryGenerator( + judge_endpoint=self.config.judge_endpoint, + task_config=self.config.task, + query_config=self.config.query_generation, + ) + self._queries = await generator.generate() + return self._queries + + async def collect_responses( + self, + queries: Optional[List[GeneratedQuery]] = None, + ) -> List[Dict[str, Any]]: + """Step 2: Collect responses from all target endpoints.""" + queries = queries or self._queries + if not queries: + raise ValueError("No queries available. 
Run generate_queries() first.") + + logger.info("Step 2: Collecting responses from target endpoints...") + collector = ResponseCollector( + target_endpoints=self.config.target_endpoints, + evaluation_config=self.config.evaluation, + ) + self._responses = await collector.collect(queries) + return self._responses + + async def generate_rubrics( + self, + sample_queries: Optional[List[str]] = None, + ) -> List[str]: + """Step 3: Generate evaluation rubrics using OpenJudge's TaskBasedRubricGenerator.""" + logger.info("Step 3: Generating evaluation rubrics...") + + if not sample_queries and self._queries: + sample_queries = [q.query for q in self._queries[:5]] + + # Use OpenJudge's TaskBasedRubricGenerator + generator = TaskBasedRubricGenerator( + model=self._create_judge_model(), + task_description=self.config.task.description, + scenario=self.config.task.scenario, + ) + self._rubrics = await generator.generate(sample_queries) + return self._rubrics + + def _prepare_pairwise_data( + self, + responses: List[Dict[str, Any]], + ) -> Tuple[List[dict], List[str]]: + """Prepare pairwise comparison dataset. + + Creates comparison pairs for all model combinations, with both + original and swapped orders to eliminate position bias. + + Args: + responses: List of response data from collect_responses() + + Returns: + Tuple of (dataset, endpoint_names) + """ + endpoint_names = list(self.config.target_endpoints.keys()) + pairs = list(combinations(endpoint_names, 2)) + + dataset = [] + for resp_data in responses: + query = resp_data["query"] + resp_dict = resp_data["responses"] + + for ep_a, ep_b in pairs: + resp_a = resp_dict.get(ep_a) + resp_b = resp_dict.get(ep_b) + + if resp_a is None or resp_b is None: + continue + + # Original order + dataset.append( + { + "evaluation_data": { + "instruction": query, + "response_a": resp_a, + "response_b": resp_b, + }, + "metadata": { + "model_a": ep_a, + "model_b": ep_b, + "order": "original", + }, + } + ) + # Swapped order (to eliminate position bias) + dataset.append( + { + "evaluation_data": { + "instruction": query, + "response_a": resp_b, + "response_b": resp_a, + }, + "metadata": { + "model_a": ep_b, + "model_b": ep_a, + "order": "swapped", + }, + } + ) + + return dataset, endpoint_names + + def _build_pairwise_grader(self, rubrics: List[str]) -> LLMGrader: + """Build pairwise comparison grader.""" + rubrics_text = "\n".join(f"- {r}" for r in rubrics) + + template = PromptTemplate( + messages=[ + ChatMessage( + role="system", + content="You are an expert evaluator. 
Compare two responses based on the given criteria.\n" + f"Evaluation Criteria:\n{rubrics_text}\n\n" + "Output JSON with 'score' (1.0 if Response A is better, 0.0 if Response B is better) " + "and 'reason' (brief explanation).", + ), + ChatMessage( + role="user", + content="Query: {instruction}\n\n" + "Response A:\n{response_a}\n\n" + "Response B:\n{response_b}\n\n" + "Which response is better based on the criteria?", + ), + ], + ) + + endpoint = self.config.judge_endpoint + extra_params = endpoint.extra_params or {} + + return LLMGrader( + name="pairwise_comparator", + mode=GraderMode.POINTWISE, + model=OpenAIChatModel( + model=endpoint.model, + api_key=endpoint.api_key, + base_url=endpoint.base_url, + temperature=extra_params.get("temperature", 0.1), + ), + template=template, + ) + + async def _run_pairwise_evaluation( + self, + dataset: List[dict], + rubrics: List[str], + ) -> List[GraderResult]: + """Run pairwise evaluation using GradingRunner.""" + grader = self._build_pairwise_grader(rubrics) + + mapper = { + "instruction": "evaluation_data.instruction", + "response_a": "evaluation_data.response_a", + "response_b": "evaluation_data.response_b", + } + + runner = GradingRunner( + grader_configs={ + "pairwise": GraderConfig(grader=grader, mapper=mapper), + }, + max_concurrency=self.config.evaluation.max_concurrency, + ) + + logger.info(f"Running {len(dataset)} pairwise comparisons...") + results = await runner.arun(dataset) + return results["pairwise"] + + def _analyze_results( + self, + dataset: List[dict], + grader_results: List[GraderResult], + endpoint_names: List[str], + ) -> EvaluationResult: + """Analyze pairwise comparison results using OpenJudge's PairwiseAnalyzer.""" + # Use OpenJudge's PairwiseAnalyzer for analysis + analyzer = PairwiseAnalyzer(model_names=endpoint_names) + analysis = analyzer.analyze(dataset, grader_results) + + # Convert to EvaluationResult + return EvaluationResult.from_analysis(analysis, total_queries=len(self._responses)) + + async def evaluate( + self, + queries: Optional[List[GeneratedQuery]] = None, + rubrics: Optional[List[str]] = None, + ) -> EvaluationResult: + """Run complete evaluation pipeline with checkpoint support. 
+ + Args: + queries: Optional pre-generated queries + rubrics: Optional pre-generated rubrics + + Returns: + EvaluationResult with rankings + """ + # Try to resume from checkpoint + checkpoint = None + if self._resume: + checkpoint = self._checkpoint_mgr.load() + + # EvaluationStage is a str Enum, so resume progress is compared by declared stage order + # rather than by lexicographic comparison of the raw string values + stage_order = list(EvaluationStage) + resumed_idx = stage_order.index(checkpoint.stage) if checkpoint else -1 + + # Step 1: Generate or load queries + if queries: + self._queries = queries + logger.info(f"Using {len(queries)} provided queries") + elif resumed_idx >= stage_order.index(EvaluationStage.QUERIES_GENERATED): + self._queries = self._checkpoint_mgr.load_queries() + logger.info(f"Resumed {len(self._queries)} queries from checkpoint") + elif not self._queries: + await self.generate_queries() + # Save checkpoint + self._checkpoint_mgr.save_queries(self._queries) + self._checkpoint_mgr.update_stage( + EvaluationStage.QUERIES_GENERATED, + total_queries=len(self._queries), + queries_file=str(self._checkpoint_mgr.output_dir / "queries.json"), + ) + + # Step 2: Collect or load responses + if resumed_idx >= stage_order.index(EvaluationStage.RESPONSES_COLLECTED): + self._responses = self._checkpoint_mgr.load_responses() + logger.info(f"Resumed {len(self._responses)} responses from checkpoint") + elif not self._responses: + await self.collect_responses() + # Save checkpoint + self._checkpoint_mgr.save_responses(self._responses) + self._checkpoint_mgr.update_stage( + EvaluationStage.RESPONSES_COLLECTED, + collected_responses=len(self._responses), + responses_file=str(self._checkpoint_mgr.output_dir / "responses.json"), + ) + + # Step 3: Generate or load rubrics + if rubrics: + self._rubrics = rubrics + logger.info(f"Using {len(rubrics)} provided rubrics") + elif resumed_idx >= stage_order.index(EvaluationStage.RUBRICS_GENERATED): + self._rubrics = self._checkpoint_mgr.load_rubrics() + logger.info(f"Resumed {len(self._rubrics)} rubrics from checkpoint") + elif not self._rubrics: + await self.generate_rubrics() + # Save checkpoint + self._checkpoint_mgr.save_rubrics(self._rubrics) + self._checkpoint_mgr.update_stage( + EvaluationStage.RUBRICS_GENERATED, + rubrics_file=str(self._checkpoint_mgr.output_dir / "rubrics.json"), + ) + + # Step 4: Run pairwise evaluation + logger.info("Step 4: Running pairwise evaluation...") + dataset, endpoint_names = self._prepare_pairwise_data(self._responses) + + if not dataset: + raise ValueError("No valid comparison pairs.
Check if responses were collected successfully.") + + grader_results = await self._run_pairwise_evaluation(dataset, self._rubrics) + + # Step 5: Analyze results using OpenJudge's PairwiseAnalyzer + logger.info("Step 5: Analyzing results...") + result = self._analyze_results(dataset, grader_results, endpoint_names) + + # Mark evaluation complete + self._checkpoint_mgr.update_stage( + EvaluationStage.EVALUATION_COMPLETE, + total_pairs=len(dataset), + evaluated_pairs=len(grader_results), + ) + + self._display_results(result) + return result + + def _display_results(self, result: EvaluationResult) -> None: + """Display evaluation results with formatted output.""" + endpoint_names = list(self.config.target_endpoints.keys()) + + # Header + logger.info("\n" + "=" * 60) + logger.info("ZERO-SHOT EVALUATION RESULTS") + logger.info("=" * 60) + + # Summary + logger.info(f"Task: {self.config.task.description[:50]}...") + logger.info(f"Queries: {result.total_queries}") + logger.info(f"Comparisons: {result.total_comparisons}") + + # Rankings + logger.info("\nRankings:") + for rank, (name, win_rate) in enumerate(result.rankings, 1): + bar_len = int(win_rate * 20) + bar = "#" * bar_len + "-" * (20 - bar_len) + logger.info(f" {rank}. {name:<20} [{bar}] {win_rate:.1%}") + + # Win Matrix + if len(endpoint_names) > 1: + logger.info("\nWin Matrix (row vs column):") + # Header row + max_name_len = max(len(n) for n in endpoint_names) + header = " " * (max_name_len + 3) + "".join(f"{n[:8]:<10}" for n in endpoint_names) + logger.info(f" {header}") + + # Data rows + for ep_a in endpoint_names: + row = f" {ep_a:<{max_name_len}} | " + for ep_b in endpoint_names: + if ep_a == ep_b: + row += f"{'--':<10}" + else: + win_rate = result.win_matrix.get(ep_a, {}).get(ep_b, 0.0) + row += f"{win_rate:<10.1%}" + logger.info(row) + + # Best pipeline + logger.info(f"\nBest Pipeline: {result.best_pipeline}") + logger.info("=" * 60) + + def save_results( + self, + result: EvaluationResult, + output_dir: Optional[Union[str, Path]] = None, + ) -> Path: + """Save evaluation results to file. + + Args: + result: Evaluation result + output_dir: Output directory (uses config default if not provided) + + Returns: + Path to saved file + """ + output_dir = Path(output_dir or self.config.output.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_file = output_dir / "evaluation_results.json" + + data = { + "result": result.model_dump(), + "config": { + "task": self.config.task.model_dump(), + "target_endpoints": list(self.config.target_endpoints.keys()), + "num_queries": self.config.query_generation.num_queries, + }, + "queries": [q.model_dump() for q in self._queries], + "rubrics": self._rubrics, + } + + with open(output_file, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + logger.info(f"Results saved to {output_file}") + return output_file + + def clear_checkpoint(self) -> None: + """Clear all checkpoint data to start fresh.""" + self._checkpoint_mgr.clear() diff --git a/docs/applications/zero_shot_evaluation.md b/docs/applications/zero_shot_evaluation.md new file mode 100644 index 00000000..cdd29507 --- /dev/null +++ b/docs/applications/zero_shot_evaluation.md @@ -0,0 +1,372 @@ +# Zero-Shot Evaluation + +Automatically evaluate and compare multiple models or AI agents without pre-existing test data. This end-to-end pipeline generates test queries, collects responses, and ranks models through pairwise comparison. 
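+
+Under the hood, the final rankings come from the `PairwiseAnalyzer` in `openjudge.analyzer`: each judged pair yields a score where `score >= 0.5` counts as a win for the response shown first, and a model's win rate is its wins divided by the comparisons it took part in. The minimal sketch below illustrates that aggregation on two hand-written comparisons; the model names, scores, and reasons are purely illustrative:
+
+```python
+from openjudge.analyzer import PairwiseAnalyzer
+from openjudge.graders.schema import GraderScore
+
+# One query judged twice, once in each presentation order (position-bias control)
+dataset = [
+    {"metadata": {"model_a": "model_x", "model_b": "model_y", "order": "original"}},
+    {"metadata": {"model_a": "model_y", "model_b": "model_x", "order": "swapped"}},
+]
+# score >= 0.5 means the response shown first (model_a) wins
+results = [
+    GraderScore(name="pairwise", score=1.0, reason="model_x is more accurate"),
+    GraderScore(name="pairwise", score=0.0, reason="model_x still better when shown second"),
+]
+
+analyzer = PairwiseAnalyzer(model_names=["model_x", "model_y"])
+analysis = analyzer.analyze(dataset, results)
+print(analysis.win_rates)   # {'model_x': 1.0, 'model_y': 0.0}
+print(analysis.best_model)  # model_x
+```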
+ + +## When to Use + +Use zero-shot evaluation for: + +- **Model Comparison** — Compare different models on a specific task without preparing test data +- **Agent Pipeline Testing** — Evaluate different agent configurations or workflows +- **New Domain Evaluation** — Quickly assess model performance in new domains +- **Rapid Prototyping** — Get quick feedback on model quality during development + + +## How It Works + +Zero-shot evaluation automates the entire evaluation pipeline: + +1. **Generate Test Queries** — Create diverse, representative queries based on task description +2. **Collect Responses** — Query all target models/agents to collect responses +3. **Generate Rubrics** — Create evaluation criteria tailored to the task +4. **Pairwise Comparison** — Compare all response pairs using a judge model +5. **Rank Models** — Calculate win rates and produce final rankings + +!!! tip "No Test Data Required" + Unlike traditional evaluation, zero-shot evaluation generates its own test queries from the task description, eliminating the need for pre-existing test datasets. + + +## Five-Step Pipeline + +| Step | Component | Description | +|------|-----------|-------------| +| 1 | `QueryGenerator` | Generate diverse test queries from task description | +| 2 | `ResponseCollector` | Collect responses from all target endpoints | +| 3 | `TaskBasedRubricGenerator` | Generate evaluation criteria for the task | +| 4 | `GradingRunner` | Run pairwise comparisons with judge model | +| 5 | `ZeroShotPipeline` | Analyze results and produce rankings | + + +## Quick Start + +### Using Configuration File (Recommended) + +```python +import asyncio +from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline + +async def main(): + pipeline = ZeroShotPipeline.from_config("config.yaml") + result = await pipeline.evaluate() + + print(f"Best Model: {result.best_pipeline}") + for rank, (model, win_rate) in enumerate(result.rankings, 1): + print(f"{rank}. {model}: {win_rate:.1%}") + +asyncio.run(main()) +``` + +### Using CLI + +```bash +# Run evaluation with config file +python -m cookbooks.zero_shot_evaluation --config config.yaml --save + +# Resume from checkpoint (default behavior) +python -m cookbooks.zero_shot_evaluation --config config.yaml --save + +# Start fresh, ignore checkpoint +python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save + +# Use pre-generated queries +python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save +``` + +### Using Pre-defined Queries + +Skip query generation by providing your own queries file. This is useful when you want to evaluate models on a specific set of questions. + +**Create a queries file** (`queries.json`): + +```json +[ + {"query": "Translate: AI is transforming industries."}, + {"query": "Translate: The weather is nice today."}, + {"query": "Translate: How to learn programming effectively?"} +] +``` + +The `category` and `difficulty` fields are optional: + +```json +[ + {"query": "Your question here", "category": "general", "difficulty": "easy"} +] +``` + +**Run evaluation**: + +```bash +python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save +``` + +The pipeline will skip query generation and directly use your queries for model comparison. 
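+
+The same can be done programmatically. Below is a minimal sketch, assuming `GeneratedQuery` (from the cookbook's `schema` module) can be constructed directly from the JSON fields shown above; it loads the queries and passes them to `evaluate()` so Step 1 is skipped:
+
+```python
+import asyncio
+import json
+
+from cookbooks.zero_shot_evaluation.schema import GeneratedQuery
+from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline
+
+
+async def main():
+    # Load pre-written queries instead of generating them
+    with open("queries.json", "r", encoding="utf-8") as f:
+        queries = [GeneratedQuery(**item) for item in json.load(f)]
+
+    pipeline = ZeroShotPipeline.from_config("config.yaml")
+    result = await pipeline.evaluate(queries=queries)  # query generation is skipped
+    print(f"Best Model: {result.best_pipeline}")
+
+
+asyncio.run(main())
+```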
+ + +## Configuration + +Create a YAML configuration file to define your evaluation: + +```yaml +# Task description +task: + description: "English to Chinese translation assistant" + scenario: "Users need to translate English content into fluent Chinese" + +# Target endpoints to evaluate +target_endpoints: + gpt4_baseline: + base_url: "https://api.openai.com/v1" + api_key: "${OPENAI_API_KEY}" + model: "gpt-4" + extra_params: + temperature: 0.7 + + qwen_candidate: + base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + api_key: "${DASHSCOPE_API_KEY}" + model: "qwen-max" + extra_params: + temperature: 0.7 + +# Judge endpoint for pairwise evaluation +judge_endpoint: + base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + api_key: "${DASHSCOPE_API_KEY}" + model: "qwen-max" + extra_params: + temperature: 0.1 + +# Query generation settings +query_generation: + num_queries: 20 + seed_queries: + - "Translate this paragraph into Chinese: 'AI is transforming industries.'" + queries_per_call: 10 + temperature: 0.9 + +# Evaluation settings +evaluation: + max_concurrency: 10 + timeout: 60 + +# Output settings +output: + output_dir: "./evaluation_results" +``` + +!!! note "Environment Variables" + Use `${ENV_VAR}` syntax to reference environment variables for sensitive data like API keys. + + +## Step-by-Step Guide + +For fine-grained control, use individual components directly: + +### Step 1: Generate Test Queries + +```python +from cookbooks.zero_shot_evaluation.query_generator import QueryGenerator +from cookbooks.zero_shot_evaluation.schema import TaskConfig, QueryGenerationConfig, OpenAIEndpoint + +# Configure task and endpoint +task = TaskConfig( + description="Code review assistant for Python", + scenario="Review code for bugs, style issues, and improvements" +) + +judge_endpoint = OpenAIEndpoint( + base_url="https://api.openai.com/v1", + api_key="your-api-key", + model="gpt-4" +) + +query_config = QueryGenerationConfig( + num_queries=20, + seed_queries=["Review this Python function for bugs..."], + enable_evolution=True, # Enable Evol-Instruct + evolution_rounds=1 +) + +generator = QueryGenerator(judge_endpoint, task, query_config) +queries = await generator.generate() +``` + +!!! 
info "Query Generation Features" + - **Parallel Batches**: Generates queries in parallel for diversity + - **Deduplication**: Automatically removes duplicate/similar queries + - **Evol-Instruct**: Optional complexity evolution for harder queries + - **Category Balancing**: Balance queries across specified categories + +### Step 2: Collect Responses + +```python +from cookbooks.zero_shot_evaluation.response_collector import ResponseCollector +from cookbooks.zero_shot_evaluation.schema import EvaluationConfig + +collector = ResponseCollector( + target_endpoints={ + "model_a": endpoint_a, + "model_b": endpoint_b, + }, + evaluation_config=EvaluationConfig(max_concurrency=10) +) + +responses = await collector.collect(queries) +``` + +### Step 3: Generate Evaluation Rubrics + +```python +from openjudge.generator.simple_rubric import TaskBasedRubricGenerator + +rubric_gen = TaskBasedRubricGenerator( + model=judge_model, + task_description=task.description, + scenario=task.scenario, +) +rubrics = await rubric_gen.generate( + sample_queries=[q.query for q in queries[:5]] +) + +# Example output: +# - Accuracy: Whether the response is factually correct +# - Completeness: Whether the response fully addresses the query +# - Clarity: Whether the response is well-organized +``` + +### Step 4: Run Full Evaluation + +```python +from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline + +pipeline = ZeroShotPipeline( + task_description="Code review assistant", + target_endpoints=target_endpoints, + judge_endpoint=judge_endpoint, + num_queries=20 +) + +result = await pipeline.evaluate() +``` + + +## Understanding Results + +The `EvaluationResult` provides comprehensive ranking statistics: + +| Field | Type | Description | +|-------|------|-------------| +| `rankings` | `List[Tuple[str, float]]` | Models sorted by win rate (best first) | +| `win_rates` | `Dict[str, float]` | Win rate for each model (0.0-1.0) | +| `win_matrix` | `Dict[str, Dict[str, float]]` | Head-to-head win rates between models | +| `best_pipeline` | `str` | Model with highest win rate | +| `total_queries` | `int` | Total number of test queries | +| `total_comparisons` | `int` | Total number of pairwise comparisons | + +!!! example "Sample Output" + ``` + ============================================================ + ZERO-SHOT EVALUATION RESULTS + ============================================================ + Task: English to Chinese translation assistant... + Queries: 20 + Comparisons: 80 + + Rankings: + 1. qwen_candidate [################----] 80.0% + 2. 
gpt4_baseline [####----------------] 20.0% + + Win Matrix (row vs column): + qwen_cand gpt4_base + qwen_candidate | -- 80.0% + gpt4_baseline | 20.0% -- + + Best Pipeline: qwen_candidate + ============================================================ + ``` + + +## Advanced Configuration + +### Query Generation Options + +| Option | Default | Description | +|--------|---------|-------------| +| `num_queries` | 20 | Total number of queries to generate | +| `queries_per_call` | 10 | Queries per API call (1-50) | +| `num_parallel_batches` | 3 | Number of parallel generation batches | +| `temperature` | 0.9 | Sampling temperature for diversity | +| `max_similarity` | 0.85 | Deduplication similarity threshold | +| `enable_evolution` | false | Enable Evol-Instruct complexity evolution | +| `evolution_rounds` | 1 | Number of evolution rounds (0-3) | + +### Evol-Instruct Evolution + +Enable complexity evolution to generate harder test queries: + +```yaml +query_generation: + enable_evolution: true + evolution_rounds: 2 + complexity_levels: + - "constraints" # Add specific constraints + - "reasoning" # Require multi-step reasoning + - "edge_cases" # Add edge cases and exceptions +``` + +!!! tip "Evolution Strategies" + - **constraints**: Add time, scope, or condition constraints + - **reasoning**: Require multi-step reasoning or comparison + - **edge_cases**: Include edge cases and unusual conditions + + +## Checkpoint & Resume + +Evaluations automatically save checkpoints, allowing resumption after interruptions: + +```bash +# First run (interrupted) +python -m cookbooks.zero_shot_evaluation --config config.yaml --save +# Progress saved at: ./evaluation_results/checkpoint.json + +# Resume from checkpoint (automatic) +python -m cookbooks.zero_shot_evaluation --config config.yaml --save +# Resumes from last completed step + +# Start fresh (ignore checkpoint) +python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save +``` + +!!! info "Checkpoint Stages" + 1. `QUERIES_GENERATED` — Test queries saved + 2. `RESPONSES_COLLECTED` — All responses saved + 3. `RUBRICS_GENERATED` — Evaluation rubrics saved + 4. `EVALUATION_COMPLETE` — Final results saved + + +## Best Practices + +!!! tip "Do" + - Start with a **clear task description** that captures the core objective + - Use **seed queries** to guide query generation style + - Set `num_queries` to at least **20** for statistically meaningful results + - Choose a **strong judge model** (at least as capable as models being evaluated) + - Use `--save` flag to persist results for later analysis + +!!!
warning "Don't" + - Use a judge model weaker than the models being evaluated + - Set `max_concurrency` too high for your API rate limits + - Skip checkpoint resumption for long-running evaluations + - Compare models with fundamentally different capabilities (e.g., text vs vision) + + +## Next Steps + +- [Pairwise Evaluation](select_rank.md) — Compare models with pre-existing test data +- [Refine Data Quality](data_refinement.md) — Use grader feedback to improve outputs +- [Create Custom Graders](../building_graders/create_custom_graders.md) — Build specialized evaluation criteria +- [Run Grading Tasks](../running_graders/run_tasks.md) — Scale evaluations with GradingRunner + + diff --git a/mkdocs.yml b/mkdocs.yml index 15282c1f..87d35bc7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,7 @@ nav: - Validate on RewardBench2: validating_graders/rewardbench2.md - Applications: + - Zero-Shot Evaluation: applications/zero_shot_evaluation.md - Refine Data Quality: applications/data_refinement.md - Pairwise Model Evaluation: applications/select_rank.md diff --git a/openjudge/analyzer/__init__.py b/openjudge/analyzer/__init__.py index bdf4e1e8..06ca95a0 100644 --- a/openjudge/analyzer/__init__.py +++ b/openjudge/analyzer/__init__.py @@ -1,9 +1,27 @@ # -*- coding: utf-8 -*- -"""Analyzer module for computing aggregated results from evaluator outputs.""" +"""Analyzer module for computing aggregated results from evaluator outputs. -from .base_analyzer import AnalysisResult, BaseAnalyzer +This module provides analyzers for processing evaluation results and +computing aggregated metrics, statistics, and insights. + +Classes: + AnalysisResult: Base class for analyzer results + BaseAnalyzer: Abstract base class for analyzers + PairwiseAnalysisResult: Result of pairwise comparison analysis + PairwiseAnalyzer: Analyzer for pairwise comparison results +""" + +from openjudge.analyzer.base_analyzer import AnalysisResult, BaseAnalyzer +from openjudge.analyzer.pairwise_analyzer import ( + PairwiseAnalysisResult, + PairwiseAnalyzer, +) __all__ = [ + # Base classes "AnalysisResult", "BaseAnalyzer", + # Pairwise analyzer + "PairwiseAnalysisResult", + "PairwiseAnalyzer", ] diff --git a/openjudge/analyzer/pairwise_analyzer.py b/openjudge/analyzer/pairwise_analyzer.py new file mode 100644 index 00000000..36964025 --- /dev/null +++ b/openjudge/analyzer/pairwise_analyzer.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +"""Pairwise comparison analyzer for computing win rates and rankings. + +This module provides the PairwiseAnalyzer class for analyzing pairwise +comparison results from LLM-based evaluations, computing win rates, +win matrices, and model rankings. + +Classes: + PairwiseAnalysisResult: Result of pairwise comparison analysis + PairwiseAnalyzer: Analyzer for pairwise comparison results +""" + +from typing import Any, Dict, List, Tuple + +from pydantic import Field + +from openjudge.analyzer.base_analyzer import AnalysisResult, BaseAnalyzer +from openjudge.graders.schema import GraderResult, GraderScore + + +class PairwiseAnalysisResult(AnalysisResult): + """Result of pairwise comparison analysis. 
+ + Attributes: + win_rates: Win rate for each model (0.0 to 1.0) + win_matrix: Win rate matrix where win_matrix[A][B] = how often A beats B + rankings: Model rankings sorted by win rate (descending) + total_comparisons: Total number of pairwise comparisons + best_model: Model with highest win rate + worst_model: Model with lowest win rate + """ + + win_rates: Dict[str, float] = Field( + default_factory=dict, + description="Win rate for each model (0.0 to 1.0)", + ) + win_matrix: Dict[str, Dict[str, float]] = Field( + default_factory=dict, + description="Win rate matrix: win_matrix[model_a][model_b] = how often A beats B", + ) + rankings: List[Tuple[str, float]] = Field( + default_factory=list, + description="Model rankings sorted by win rate", + ) + total_comparisons: int = Field(default=0, description="Total number of pairwise comparisons") + best_model: str = Field(default="", description="Model with highest win rate") + worst_model: str = Field(default="", description="Model with lowest win rate") + + +class PairwiseAnalyzer(BaseAnalyzer): + """Analyzer for pairwise comparison results. + + This analyzer computes win rates and rankings from pairwise comparison results. + It processes the results from pairwise LLM evaluations where each comparison + yields a score indicating which response is better. + + The analyzer expects dataset samples to contain metadata with 'model_a' and + 'model_b' keys indicating which models produced the compared responses. + + Attributes: + name: Name of the analyzer + model_names: List of all model names being compared + + Example: + >>> analyzer = PairwiseAnalyzer(model_names=["gpt-4", "claude", "gemini"]) + >>> result = analyzer.analyze(dataset, grader_results) + >>> print(f"Best model: {result.best_model}") + >>> print(f"Rankings: {result.rankings}") + """ + + name: str = "Pairwise Win Rate Analysis" + + def __init__(self, model_names: List[str]): + """Initialize PairwiseAnalyzer. + + Args: + model_names: List of all model names being compared + """ + self.model_names = model_names + + def _initialize_model_matrix(self) -> Dict[str, Dict[str, int]]: + """Initialize nested dictionary for model comparison counts.""" + return {m: {n: 0 for n in self.model_names if n != m} for m in self.model_names} + + def analyze( + self, + dataset: List[dict], + grader_results: List[GraderResult], + **kwargs: Any, + ) -> PairwiseAnalysisResult: + """Analyze pairwise comparison results and compute win rates. + + This method processes the grader results from pairwise comparisons, + counting wins for each model and computing win rates and rankings. + + The score interpretation: + - score >= 0.5: model_a wins + - score < 0.5: model_b wins + + Args: + dataset: List of pairwise comparison samples. Each sample should have + a 'metadata' dict containing 'model_a' and 'model_b' keys. + grader_results: Grader results with scores (1.0 for first wins, + 0.0 for second wins, or values in between) + **kwargs: Additional parameters (unused) + + Returns: + PairwiseAnalysisResult with win rates, win matrix, and rankings + + Example: + >>> # Dataset format + >>> dataset = [ + ... {"metadata": {"model_a": "gpt-4", "model_b": "claude", "order": "original"}}, + ... {"metadata": {"model_a": "claude", "model_b": "gpt-4", "order": "swapped"}}, + ... ] + >>> # GraderScore with score=1.0 means first model (model_a) wins + >>> results = [GraderScore(name="pairwise", score=1.0, reason="..."), ...] 
+ >>> analyzer = PairwiseAnalyzer(model_names=["gpt-4", "claude"]) + >>> analysis = analyzer.analyze(dataset, results) + """ + # Pre-extract all metadata to avoid repeated dict lookups + metadata_list = [sample.get("metadata", {}) for sample in dataset] + + # Initialize win counts (use integers for counting) + win_counts = self._initialize_model_matrix() + comparison_counts = self._initialize_model_matrix() + + # Use zip to pair results with metadata in one pass + for metadata, result in zip(metadata_list, grader_results): + model_a = metadata.get("model_a") + model_b = metadata.get("model_b") + + if not model_a or not model_b or not isinstance(result, GraderScore): + continue + + # score >= 0.5 means model_a wins, otherwise model_b wins + if result.score >= 0.5: + win_counts[model_a][model_b] += 1 + else: + win_counts[model_b][model_a] += 1 + + # Both models participated in this comparison + comparison_counts[model_a][model_b] += 1 + comparison_counts[model_b][model_a] += 1 + + # Calculate win matrix in single comprehension + win_matrix = { + model_a: { + model_b: ( + win_counts[model_a][model_b] / comparison_counts[model_a][model_b] + if comparison_counts[model_a][model_b] > 0 + else 0.0 + ) + for model_b in self.model_names + if model_a != model_b + } + for model_a in self.model_names + } + + # Calculate win rates using comprehension + win_rates = { + model: ( + sum(win_counts[model].values()) / sum(comparison_counts[model].values()) + if sum(comparison_counts[model].values()) > 0 + else 0.0 + ) + for model in self.model_names + } + + # Sort by win rate (descending) + rankings = sorted(win_rates.items(), key=lambda x: x[1], reverse=True) + + return PairwiseAnalysisResult( + name=self.name, + win_rates=win_rates, + win_matrix=win_matrix, + rankings=rankings, + total_comparisons=len(grader_results), + best_model=rankings[0][0] if rankings else "", + worst_model=rankings[-1][0] if rankings else "", + metadata={ + "num_models": len(self.model_names), + "explanation": ( + f"Analyzed {len(grader_results)} pairwise comparisons " f"across {len(self.model_names)} models" + ), + }, + ) diff --git a/openjudge/generator/__init__.py b/openjudge/generator/__init__.py index e69de29b..0e3da8c3 100644 --- a/openjudge/generator/__init__.py +++ b/openjudge/generator/__init__.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +"""Generator module for creating graders and evaluation rubrics. + +This module provides generators for automatically creating graders and +evaluation criteria based on data or task descriptions. 
+ +Submodules: + simple_rubric: Task-description-based rubric generation (zero-shot) + iterative_rubric: Preference-data-based rubric generation (iterative refinement) + +Classes: + BaseGraderGenerator: Abstract base class for grader generators + GraderGeneratorConfig: Configuration for grader generation + LLMGraderGenerator: Base class for LLM-based grader generators + LLMGraderGeneratorConfig: Configuration for LLM grader generation + + # Simple rubric generation (from task description) + SimpleRubricsGenerator: Main generator for simple rubric-based graders + SimpleRubricsGeneratorConfig: Configuration for simple rubric generation + TaskBasedRubricGenerator: Core rubric generation logic + +Constants: + DEFAULT_RUBRICS: Default fallback rubrics if generation fails +""" + +from openjudge.generator.base_generator import ( + BaseGraderGenerator, + GraderGeneratorConfig, +) +from openjudge.generator.llm_grader_generator import ( + LLMGraderGenerator, + LLMGraderGeneratorConfig, +) + +# Simple rubric generation +from openjudge.generator.simple_rubric import ( + DEFAULT_RUBRICS, + SimpleRubricsGenerator, + SimpleRubricsGeneratorConfig, + TaskBasedRubricGenerator, +) + +__all__ = [ + # Base classes + "BaseGraderGenerator", + "GraderGeneratorConfig", + "LLMGraderGenerator", + "LLMGraderGeneratorConfig", + # Simple rubric generation + "SimpleRubricsGenerator", + "SimpleRubricsGeneratorConfig", + "TaskBasedRubricGenerator", + "DEFAULT_RUBRICS", +] diff --git a/openjudge/generator/simple_rubric/__init__.py b/openjudge/generator/simple_rubric/__init__.py new file mode 100644 index 00000000..2378311e --- /dev/null +++ b/openjudge/generator/simple_rubric/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +"""Simple rubric generator module for automatic evaluation criteria generation. + +This module provides a simple, task-description-based approach to generating +evaluation rubrics. It generates rubrics from task descriptions and sample +queries, without requiring labeled training data. + +This is in contrast to the iterative_rubric module which learns rubrics from +preference data through an iterative refinement process. + +Classes: + SimpleRubricsGenerator: Main generator class that creates LLMGrader instances + SimpleRubricsGeneratorConfig: Configuration for the generator + TaskBasedRubricGenerator: Core rubric generation logic + +Constants: + DEFAULT_RUBRICS: Default fallback rubrics if generation fails +""" + +from openjudge.generator.simple_rubric.generator import ( + SimpleRubricsGenerator, + SimpleRubricsGeneratorConfig, +) +from openjudge.generator.simple_rubric.rubric_generator import ( + DEFAULT_RUBRICS, + TaskBasedRubricGenerator, +) + +__all__ = [ + # Main generator (creates LLMGrader) + "SimpleRubricsGenerator", + "SimpleRubricsGeneratorConfig", + # Core rubric generation logic + "TaskBasedRubricGenerator", + "DEFAULT_RUBRICS", +] diff --git a/openjudge/generator/simple_rubric/generator.py b/openjudge/generator/simple_rubric/generator.py new file mode 100644 index 00000000..f9067ecc --- /dev/null +++ b/openjudge/generator/simple_rubric/generator.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +"""Simple rubrics generator implementation. + +This module implements a task-description-based approach to generating +evaluation rubrics. It creates LLMGrader instances with rubrics generated +from task descriptions and sample queries. + +This is a simpler alternative to the iterative_rubric module, which learns +rubrics from preference data through an iterative refinement process. 
+ +Usage: + >>> from openjudge.generator.simple_rubric import SimpleRubricsGenerator, SimpleRubricsGeneratorConfig + >>> from openjudge.models.openai_chat_model import OpenAIChatModel + >>> + >>> config = SimpleRubricsGeneratorConfig( + ... grader_name="Medical QA Grader", + ... model=OpenAIChatModel(model="gpt-4o-mini"), + ... task_description="Medical question answering system", + ... scenario="Healthcare professionals seeking quick answers" + ... ) + >>> generator = SimpleRubricsGenerator(config) + >>> grader = await generator.generate(dataset=[], sample_queries=["What are the symptoms of flu?"]) +""" + +from dataclasses import dataclass, field +from typing import List, Optional + +from loguru import logger + +from openjudge.generator.iterative_rubric.query_rubric_generator import ( + LISTWISE_EVALUATION_TEMPLATE, + POINTWISE_EVALUATION_TEMPLATE, +) +from openjudge.generator.llm_grader_generator import ( + LLMGraderGenerator, + LLMGraderGeneratorConfig, +) +from openjudge.generator.simple_rubric.rubric_generator import TaskBasedRubricGenerator +from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.schema import GraderMode +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.prompt_template import LanguageEnum + + +@dataclass +class SimpleRubricsGeneratorConfig(LLMGraderGeneratorConfig): + """Configuration for simple rubrics generator. + + This configuration extends LLMGraderGeneratorConfig with parameters + specific to task-description-based rubric generation. + + Attributes: + task_description: Description of the task for evaluation. + scenario: Optional usage scenario for context. + language: Language for prompts (ZH or EN). Defaults to EN. + default_rubrics: Fallback rubrics if generation fails. + max_retries: Maximum number of retry attempts for LLM calls. + min_score: Minimum score for pointwise evaluation. + max_score: Maximum score for pointwise evaluation. + + Inherited from LLMGraderGeneratorConfig: + grader_name: Human-readable name for the generated grader. + model: Language model to use for generation. + grader_mode: Mode for the generated grader (POINTWISE or LISTWISE). + custom_evaluation_prompt: Custom template for evaluation. + """ + + task_description: str = "" + scenario: Optional[str] = None + language: LanguageEnum = LanguageEnum.EN + default_rubrics: List[str] = field(default_factory=list) + max_retries: int = 3 + min_score: int = 0 + max_score: int = 1 + + def __post_init__(self): + """Process model configuration if provided as dict.""" + if isinstance(self.model, dict): + self.model = OpenAIChatModel(**self.model) + + +class SimpleRubricsGenerator(LLMGraderGenerator): + """Generator for creating LLM-based graders with task-description-based rubrics. + + This generator implements a simple approach to rubric generation: + 1. Takes a task description and optional sample queries + 2. Uses an LLM to generate relevant evaluation criteria + 3. Creates an LLMGrader configured with these rubrics + + Example: + >>> config = SimpleRubricsGeneratorConfig( + ... grader_name="Medical QA Grader", + ... model=OpenAIChatModel(model="gpt-4o-mini"), + ... task_description="Medical question answering system", + ... scenario="Healthcare professionals seeking quick answers" + ... ) + >>> generator = SimpleRubricsGenerator(config) + >>> grader = await generator.generate( + ... dataset=[], + ... sample_queries=["What are the symptoms of flu?"] + ... 
) + """ + + def __init__(self, config: SimpleRubricsGeneratorConfig) -> None: + """Initialize the simple rubrics generator. + + Args: + config: Configuration for rubric generation. + """ + super().__init__(config) + self.config: SimpleRubricsGeneratorConfig = config + + self._rubric_generator = TaskBasedRubricGenerator( + model=config.model, + task_description=config.task_description, + scenario=config.scenario, + language=config.language, + default_rubrics=config.default_rubrics, + max_retries=config.max_retries, + ) + + async def generate( + self, + dataset: List[dict], + sample_queries: Optional[List[str]] = None, + **kwargs, + ) -> LLMGrader: + """Generate an LLMGrader with rubrics from task description. + + Args: + dataset: List of data dictionaries (used to extract sample queries + if sample_queries is not provided). + sample_queries: Optional list of sample queries for context. + **kwargs: Additional arguments (currently unused). + + Returns: + LLMGrader: Configured grader instance with generated rubrics. + """ + if sample_queries is None and dataset: + sample_queries = [d.get("query", "") for d in dataset[:5] if d.get("query")] + + rubrics = await self._generate_rubrics(sample_queries) + + grader_kwargs = { + "name": self.config.grader_name, + "model": self.config.model, + "mode": self.config.grader_mode, + "rubrics": rubrics, + "language": self.config.language, + } + + if self.config.grader_mode == GraderMode.POINTWISE: + grader_kwargs["min_score"] = self.config.min_score + grader_kwargs["max_score"] = self.config.max_score + + if self.config.custom_evaluation_prompt is not None: + grader_kwargs["template"] = self.config.custom_evaluation_prompt + else: + if self.config.grader_mode == GraderMode.POINTWISE: + grader_kwargs["template"] = POINTWISE_EVALUATION_TEMPLATE + else: + grader_kwargs["template"] = LISTWISE_EVALUATION_TEMPLATE + + return LLMGrader(**grader_kwargs) + + async def _generate_rubrics( + self, + dataset: Optional[List[str]] = None, # pylint: disable=arguments-renamed + ) -> str: + """Generate rubrics from task description. + + Args: + dataset: Optional list of sample queries for context. + + Returns: + str: Formatted string containing evaluation rubrics. + """ + rubrics_list = await self._rubric_generator.generate(sample_queries=dataset) + + formatted_rubrics = "\n\n".join([f"{i + 1}. {rubric}" for i, rubric in enumerate(rubrics_list)]) + + logger.info(f"Generated {len(rubrics_list)} rubrics from task description") + + return formatted_rubrics diff --git a/openjudge/generator/simple_rubric/rubric_generator.py b/openjudge/generator/simple_rubric/rubric_generator.py new file mode 100644 index 00000000..0e22557b --- /dev/null +++ b/openjudge/generator/simple_rubric/rubric_generator.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- +"""Task-based rubric generator for automatic evaluation criteria generation. + +This module provides functionality to automatically generate evaluation rubrics +based on task descriptions, enabling zero-shot evaluation pipelines. + +The generator uses an LLM to analyze the task description and sample queries +to produce relevant evaluation criteria without requiring labeled training data. + +Classes: + TaskBasedRubricGenerator: Generator for evaluation rubrics. 
+""" + +from typing import List, Optional + +from loguru import logger +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_fixed + +from openjudge.models.base_chat_model import BaseChatModel +from openjudge.models.schema.oai.message import ChatMessage +from openjudge.models.schema.prompt_template import LanguageEnum, PromptTemplate + +# ============================================================================= +# Constants +# ============================================================================= + +DEFAULT_RUBRICS: List[str] = [ + "Accuracy: Whether the response is factually correct", + "Relevance: Whether the response addresses the query", + "Completeness: Whether the response is comprehensive", +] + +# ============================================================================= +# Prompt Templates +# ============================================================================= + +RUBRIC_GENERATION_PROMPT_EN = """# Task +Generate evaluation rubrics for pairwise comparison of model responses. + +## Task Description +{task_description} + +## Scenario +{scenario} + +## Sample Queries (for context) +{sample_queries} + +## Requirements +- Generate 3-5 clear evaluation criteria for comparing two responses +- Each criterion should be objective and measurable +- Criteria should be relevant to the task and scenario +- Focus on aspects that distinguish good responses from poor ones + +## Output Format +Return a JSON object with: +- rubrics: list of evaluation criteria strings +- reason: brief explanation of why these criteria are important + +Example: +{{ + "rubrics": [ + "Accuracy: Whether the response contains correct and factual information", + "Completeness: Whether the response fully addresses the query", + "Clarity: Whether the response is well-organized and easy to understand" + ], + "reason": "These criteria capture the key aspects for evaluating..." +}} +""" + +RUBRIC_GENERATION_PROMPT_ZH = """# 任务 +为模型回答的成对比较生成评估标准。 + +## 任务描述 +{task_description} + +## 使用场景 +{scenario} + +## 示例查询(用于上下文理解) +{sample_queries} + +## 要求 +- 生成3-5个清晰的评估标准用于比较两个回答 +- 每个标准应该客观且可测量 +- 标准应与任务和场景相关 +- 聚焦于能够区分好回答和差回答的方面 + +## 输出格式 +返回一个JSON对象,包含: +- rubrics: 评估标准字符串列表 +- reason: 简要解释为什么这些标准是重要的 + +示例: +{{ + "rubrics": [ + "准确性:回答是否包含正确和真实的信息", + "完整性:回答是否完整地解决了问题", + "清晰度:回答是否组织良好、易于理解" + ], + "reason": "这些标准捕捉了评估的关键方面..." +}} +""" + +RUBRIC_GENERATION_TEMPLATE = PromptTemplate( + messages={ + LanguageEnum.EN: [ + ChatMessage( + role="system", + content="You are an expert at designing evaluation criteria for AI systems.", + ), + ChatMessage(role="user", content=RUBRIC_GENERATION_PROMPT_EN), + ], + LanguageEnum.ZH: [ + ChatMessage( + role="system", + content="你是一位设计AI系统评估标准的专家。", + ), + ChatMessage(role="user", content=RUBRIC_GENERATION_PROMPT_ZH), + ], + }, +) + + +# ============================================================================= +# Output Schema +# ============================================================================= + + +class RubricGenerationOutput(BaseModel): + """Output schema for rubric generation.""" + + rubrics: List[str] = Field(..., description="List of evaluation rubrics") + reason: str = Field(default="", description="Reasoning for these rubrics") + + +# ============================================================================= +# TaskBasedRubricGenerator +# ============================================================================= + + +class TaskBasedRubricGenerator: + """Generate evaluation rubrics based on task description. 
+ + This generator creates evaluation rubrics that can be used for pairwise + comparison or other evaluation scenarios. It uses an LLM to generate + task-specific criteria based on the provided task description. + + Example: + >>> from openjudge.models.openai_chat_model import OpenAIChatModel + >>> from openjudge.generator.simple_rubric import TaskBasedRubricGenerator + >>> + >>> model = OpenAIChatModel(model="gpt-4o-mini") + >>> generator = TaskBasedRubricGenerator( + ... model=model, + ... task_description="Medical question answering system", + ... scenario="Healthcare professionals seeking quick answers" + ... ) + >>> rubrics = await generator.generate(sample_queries=["What are the symptoms of flu?"]) + """ + + def __init__( + self, + model: BaseChatModel, + task_description: str, + scenario: Optional[str] = None, + language: LanguageEnum = LanguageEnum.EN, + default_rubrics: Optional[List[str]] = None, + max_retries: int = 3, + ): + """Initialize TaskBasedRubricGenerator. + + Args: + model: Language model for generating rubrics. + task_description: Description of the task for evaluation. + scenario: Optional usage scenario for context. + language: Language for prompts (ZH or EN). Defaults to EN. + default_rubrics: Fallback rubrics if generation fails. + max_retries: Maximum number of retry attempts for LLM calls. + """ + self.model = model + self.task_description = task_description + self.scenario = scenario + self.language = language + self.default_rubrics = default_rubrics or DEFAULT_RUBRICS.copy() + self.max_retries = max_retries + + async def generate( + self, + sample_queries: Optional[List[str]] = None, + ) -> List[str]: + """Generate evaluation rubrics. + + Args: + sample_queries: Optional sample queries for context. + These help the LLM understand what kind of + queries will be evaluated. + + Returns: + List of rubric strings + """ + + @retry(stop=stop_after_attempt(self.max_retries), wait=wait_fixed(1.0)) + async def _generate() -> List[str]: + queries_text = "None provided" + if sample_queries: + queries_text = "\n".join(f"- {q}" for q in sample_queries[:5]) + + messages = RUBRIC_GENERATION_TEMPLATE.format( + task_description=self.task_description, + scenario=self.scenario or "General usage", + sample_queries=queries_text, + language=self.language, + ) + + response = await self.model.achat( + messages=list(messages), + structured_model=RubricGenerationOutput, + ) + + if not response.parsed or "rubrics" not in response.parsed: + raise ValueError("Failed to parse rubric generation response") + + return response.parsed["rubrics"] + + try: + rubrics = await _generate() + logger.info(f"Generated {len(rubrics)} evaluation rubrics") + for i, rubric in enumerate(rubrics, 1): + logger.debug(f" {i}. {rubric}") + return rubrics + except Exception as e: + logger.error(f"Rubric generation failed: {e}") + logger.warning("Using default rubrics as fallback") + return self.default_rubrics diff --git a/tests/generator/test_simple_rubric.py b/tests/generator/test_simple_rubric.py new file mode 100644 index 00000000..101ad9a9 --- /dev/null +++ b/tests/generator/test_simple_rubric.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +"""Simple Rubric Generator test module. + +This module contains unit tests for the Simple Rubric Generator functionality +which generates evaluation rubrics from task descriptions. + +Demonstrates workflow: +1. Create generator with task description configuration +2. Generate rubrics from task description (no labeled data required) +3. 
Optionally create a complete LLMGrader for evaluation + +Supports both TaskBasedRubricGenerator (rubrics only) and +SimpleRubricsGenerator (complete LLMGrader). + +Example: + Run all tests: + ```bash + pytest tests/generator/test_simple_rubric.py -v + ``` + + Run a specific test: + ```bash + pytest tests/generator/test_simple_rubric.py::test_task_based_rubric_generator -v + ``` + + Run directly as a script: + ```bash + python tests/generator/test_simple_rubric.py + ``` +""" + +import asyncio + +import pytest +from loguru import logger + +from openjudge.generator.simple_rubric import ( + SimpleRubricsGenerator, + SimpleRubricsGeneratorConfig, + TaskBasedRubricGenerator, +) +from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.schema import GraderMode, GraderScore +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.prompt_template import LanguageEnum + +# ============================================================================= +# Test Data +# ============================================================================= + +# Task description for testing +TEST_TASK_DESCRIPTION = "English to Chinese translation assistant that helps users translate technical documents" + +TEST_SCENARIO = "Users need to translate technical documentation from English to fluent, accurate Chinese" + +# Sample queries for context +TEST_SAMPLE_QUERIES = [ + "Translate this paragraph into Chinese: 'Machine learning is a subset of artificial intelligence.'", + "Translate the following technical term: 'neural network'", + "How would you translate 'API endpoint' into Chinese?", +] + +# Test data for evaluation +TEST_EVALUATION_DATA = { + "query": "Translate this sentence: 'The database query returned an error.'", + "response": "数据库查询返回了一个错误。", +} + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_test_model() -> OpenAIChatModel: + """Get test model instance. + + Returns: + OpenAIChatModel: Configured OpenAI chat model for testing. + """ + return OpenAIChatModel(model="qwen3-32b", stream=False) + + +# ============================================================================= +# TaskBasedRubricGenerator Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_task_based_rubric_generator() -> None: + """Test TaskBasedRubricGenerator for generating rubrics from task description.""" + model = get_test_model() + + generator = TaskBasedRubricGenerator( + model=model, + task_description=TEST_TASK_DESCRIPTION, + scenario=TEST_SCENARIO, + language=LanguageEnum.EN, + ) + rubrics = await generator.generate(sample_queries=TEST_SAMPLE_QUERIES) + + # Verify rubrics were generated + assert rubrics is not None, "Rubrics should not be None" + assert isinstance(rubrics, list), f"Rubrics should be a list, got {type(rubrics)}" + assert len(rubrics) > 0, "Rubrics list should not be empty" + + # Verify each rubric is a non-empty string + for i, rubric in enumerate(rubrics): + assert isinstance(rubric, str), f"Rubric {i} should be a string, got {type(rubric)}" + assert len(rubric) > 0, f"Rubric {i} should not be empty" + + logger.info(f"Generated {len(rubrics)} rubrics:") + for i, rubric in enumerate(rubrics, 1): + logger.info(f" {i}. 
{rubric}") + + +@pytest.mark.asyncio +async def test_task_based_rubric_generator_chinese() -> None: + """Test TaskBasedRubricGenerator with Chinese language prompts.""" + model = get_test_model() + + generator = TaskBasedRubricGenerator( + model=model, + task_description="代码审查助手,帮助开发者检查 Python 代码质量", + scenario="开发者需要对代码进行质量检查和改进建议", + language=LanguageEnum.ZH, + ) + rubrics = await generator.generate( + sample_queries=[ + "请审查这段代码是否有bug", + "这个函数的命名是否合理?", + ] + ) + + # Verify rubrics were generated + assert rubrics is not None, "Rubrics should not be None" + assert isinstance(rubrics, list), f"Rubrics should be a list, got {type(rubrics)}" + assert len(rubrics) > 0, "Rubrics list should not be empty" + + logger.info(f"Generated {len(rubrics)} Chinese rubrics:") + for i, rubric in enumerate(rubrics, 1): + logger.info(f" {i}. {rubric}") + + +@pytest.mark.asyncio +async def test_task_based_rubric_generator_default_fallback() -> None: + """Test that default rubrics are returned when generation fails.""" + model = get_test_model() + + default_rubrics = [ + "Custom default rubric 1", + "Custom default rubric 2", + ] + + generator = TaskBasedRubricGenerator( + model=model, + task_description=TEST_TASK_DESCRIPTION, + scenario=TEST_SCENARIO, + default_rubrics=default_rubrics, + ) + + # Verify default_rubrics are set + assert generator.default_rubrics == default_rubrics + + +# ============================================================================= +# SimpleRubricsGenerator Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_simple_rubrics_generator_pointwise() -> None: + """Test SimpleRubricsGenerator for creating a complete LLMGrader (pointwise mode).""" + model = get_test_model() + + config = SimpleRubricsGeneratorConfig( + grader_name="Translation_Quality_Grader", + model=model, + grader_mode=GraderMode.POINTWISE, + task_description=TEST_TASK_DESCRIPTION, + scenario=TEST_SCENARIO, + language=LanguageEnum.EN, + min_score=0, + max_score=5, + ) + + generator = SimpleRubricsGenerator(config) + grader = await generator.generate( + dataset=[], + sample_queries=TEST_SAMPLE_QUERIES, + ) + + # Verify grader was created + assert grader is not None, "Grader should not be None" + assert isinstance(grader, LLMGrader), f"Grader should be LLMGrader, got {type(grader)}" + assert grader.name == "Translation_Quality_Grader", f"Grader name mismatch" + + # Verify rubrics were generated + rubrics = grader.kwargs.get("rubrics") + assert rubrics is not None, "Rubrics key should exist in kwargs" + assert len(rubrics) > 0, "Rubrics should not be empty" + + logger.info(f"Generated rubrics:\n{rubrics}") + + # Evaluate test sample + result = await grader.aevaluate( + query=TEST_EVALUATION_DATA["query"], + response=TEST_EVALUATION_DATA["response"], + ) + + # Verify result structure + assert result is not None, "Evaluation result should not be None" + assert isinstance(result, GraderScore), f"Result should be GraderScore, got {type(result)}" + assert result.score is not None, "Score should not be None" + assert isinstance(result.score, (int, float)), f"Score should be numeric, got {type(result.score)}" + assert result.reason is not None, "Reason should not be None" + + logger.info(f"Pointwise evaluation result: {result}") + + +@pytest.mark.asyncio +async def test_simple_rubrics_generator_extract_queries_from_dataset() -> None: + """Test that SimpleRubricsGenerator extracts sample queries from dataset.""" + model = get_test_model() + + config 
= SimpleRubricsGeneratorConfig( + grader_name="Auto_Query_Extraction_Grader", + model=model, + task_description=TEST_TASK_DESCRIPTION, + scenario=TEST_SCENARIO, + ) + + # Provide dataset with queries but no explicit sample_queries + dataset = [ + {"query": "Translate: Hello world", "response": "你好世界"}, + {"query": "Translate: Good morning", "response": "早上好"}, + {"query": "Translate: Thank you", "response": "谢谢"}, + ] + + generator = SimpleRubricsGenerator(config) + grader = await generator.generate(dataset=dataset) # No sample_queries provided + + # Verify grader was created + assert grader is not None, "Grader should not be None" + assert isinstance(grader, LLMGrader), f"Grader should be LLMGrader, got {type(grader)}" + + # Verify rubrics were generated (queries should be extracted from dataset) + rubrics = grader.kwargs.get("rubrics") + assert rubrics is not None, "Rubrics key should exist in kwargs" + assert len(rubrics) > 0, "Rubrics should not be empty" + + logger.info(f"Generated rubrics from dataset queries:\n{rubrics}") + + +@pytest.mark.asyncio +async def test_simple_rubrics_generator_with_model_dict() -> None: + """Test SimpleRubricsGenerator with model configuration as dictionary.""" + config = SimpleRubricsGeneratorConfig( + grader_name="Dict_Config_Grader", + model={"model": "qwen3-32b", "stream": False}, # Dict instead of model instance + task_description=TEST_TASK_DESCRIPTION, + scenario=TEST_SCENARIO, + ) + + generator = SimpleRubricsGenerator(config) + + # Verify model was converted from dict to OpenAIChatModel + assert generator.config.model is not None + assert isinstance(generator.config.model, OpenAIChatModel) + + grader = await generator.generate(dataset=[], sample_queries=TEST_SAMPLE_QUERIES) + + assert grader is not None, "Grader should not be None" + assert isinstance(grader, LLMGrader), f"Grader should be LLMGrader, got {type(grader)}" + + logger.info("Successfully created grader with dict model config") + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + + +async def main() -> None: + """Run all test functions.""" + logger.info("Running TaskBasedRubricGenerator tests...") + await test_task_based_rubric_generator() + await test_task_based_rubric_generator_chinese() + await test_task_based_rubric_generator_default_fallback() + + logger.info("\nRunning SimpleRubricsGenerator tests...") + await test_simple_rubrics_generator_pointwise() + await test_simple_rubrics_generator_extract_queries_from_dataset() + await test_simple_rubrics_generator_with_model_dict() + + logger.info("\nAll tests passed!") + + +if __name__ == "__main__": + asyncio.run(main())