diff --git a/.gitignore b/.gitignore index 0549d46a..a4dfe0da 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,16 @@ qa_pairs* Khauneesh/ *job_args* +# Generated data files +freeform_data_*.json +row_data_*.json +lending_*.json +seeds_*.json +SeedsInstructions.json +*_example.json +nm.json +french_input.json + # DB *metadata.db-shm *metadata.db-wal diff --git a/app/client/src/pages/DataGenerator/Finish.tsx b/app/client/src/pages/DataGenerator/Finish.tsx index b54e5fee..60f7d4de 100644 --- a/app/client/src/pages/DataGenerator/Finish.tsx +++ b/app/client/src/pages/DataGenerator/Finish.tsx @@ -255,7 +255,8 @@ const Finish = () => { title: 'Review Dataset', description: 'Review your dataset to ensure it properly fits your usecase.', icon: , - href: getFilesURL(genDatasetResp?.export_path?.local || "") + href: getFilesURL(genDatasetResp?.export_path?.local || ""), + external: true }, { avatar: '', @@ -278,7 +279,8 @@ const Finish = () => { title: 'Review Dataset', description: 'Once your dataset finishes generating, you can review your dataset in the workbench files', icon: , - href: getFilesURL('') + href: getFilesURL(''), + external: true }, { avatar: '', @@ -361,7 +363,18 @@ const Finish = () => { ( + renderItem={({ title, href, icon, description, external }, i) => ( + external ? + + + } + title={title} + description={description} + /> + + : + props.theme.color}; - background-color: ${props => props.theme.backgroundColor}; - border: 1px solid ${props => props.theme.borderColor}; +const StyledTag = styled(Tag)<{ $theme: { color: string; backgroundColor: string; borderColor: string } }>` + color: ${props => props.$theme.color} !important; + background-color: ${props => props.$theme.backgroundColor} !important; + border: 1px solid ${props => props.$theme.borderColor} !important; `; @@ -150,7 +150,7 @@ const TemplateCard: React.FC = ({ template }) => { const { color, backgroundColor, borderColor } = getTemplateTagColors(theme as string); return ( - +
{tag}
diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py index 92391b8c..b61a0ec7 100644 --- a/app/core/model_handlers.py +++ b/app/core/model_handlers.py @@ -180,6 +180,8 @@ def generate_response( return self._handle_caii_request(prompt) if self.inference_type == "openai": return self._handle_openai_request(prompt) + if self.inference_type == "openai_compatible": + return self._handle_openai_compatible_request(prompt) if self.inference_type == "gemini": return self._handle_gemini_request(prompt) raise ModelHandlerError(f"Unsupported inference_type={self.inference_type}", 400) @@ -342,6 +344,66 @@ def _handle_openai_request(self, prompt: str): except Exception as e: raise ModelHandlerError(f"OpenAI request failed: {e}", 500) + # ---------- OpenAI Compatible ------------------------------------------------------- + def _handle_openai_compatible_request(self, prompt: str): + """Handle OpenAI compatible endpoints with proper timeout configuration""" + try: + import httpx + from openai import OpenAI + + # Get API key from environment variable (only credential needed) + api_key = os.getenv('OpenAI_Endpoint_Compatible_Key') + if not api_key: + raise ModelHandlerError("OpenAI_Endpoint_Compatible_Key environment variable not set", 500) + + # Base URL comes from caii_endpoint parameter (passed during initialization) + openai_compatible_endpoint = self.caii_endpoint + if not openai_compatible_endpoint: + raise ModelHandlerError("OpenAI compatible endpoint not provided", 500) + + # Configure timeout for OpenAI compatible client (same as OpenAI v1.57.2) + timeout_config = httpx.Timeout( + connect=self.OPENAI_CONNECT_TIMEOUT, + read=self.OPENAI_READ_TIMEOUT, + write=10.0, + pool=5.0 + ) + + # Configure httpx client with certificate verification for private cloud + if os.path.exists("/etc/ssl/certs/ca-certificates.crt"): + http_client = httpx.Client( + verify="/etc/ssl/certs/ca-certificates.crt", + timeout=timeout_config + ) + else: + http_client = httpx.Client(timeout=timeout_config) + + # Remove trailing '/chat/completions' if present (similar to CAII handling) + openai_compatible_endpoint = openai_compatible_endpoint.removesuffix('/chat/completions') + + client = OpenAI( + api_key=api_key, + base_url=openai_compatible_endpoint, + http_client=http_client + ) + + completion = client.chat.completions.create( + model=self.model_id, + messages=[{"role": "user", "content": prompt}], + max_tokens=self.model_params.max_tokens, + temperature=self.model_params.temperature, + top_p=self.model_params.top_p, + stream=False, + ) + + print("generated via OpenAI Compatible endpoint") + response_text = completion.choices[0].message.content + + return self._extract_json_from_text(response_text) if not self.custom_p else response_text + + except Exception as e: + raise ModelHandlerError(f"OpenAI Compatible request failed: {str(e)}", status_code=500) + # ---------- Gemini ------------------------------------------------------- def _handle_gemini_request(self, prompt: str): if genai is None: diff --git a/app/main.py b/app/main.py index b81bc05c..2becafb8 100644 --- a/app/main.py +++ b/app/main.py @@ -42,8 +42,10 @@ sys.path.append(str(ROOT_DIR)) from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath, Technique from app.services.synthesis_service import SynthesisService +from 
app.services.synthesis_legacy_service import SynthesisLegacyService from app.services.export_results import Export_Service from app.core.prompt_templates import PromptBuilder, PromptHandler @@ -66,8 +68,10 @@ #****************************************Initialize************************************************ # Initialize services -synthesis_service = SynthesisService() -evaluator_service = EvaluatorService() +synthesis_service = SynthesisService() # Freeform only +synthesis_legacy_service = SynthesisLegacyService() # SFT and Custom_Workflow +evaluator_service = EvaluatorService() # Freeform only +evaluator_legacy_service = EvaluatorLegacyService() # SFT and Custom_Workflow export_service = Export_Service() db_manager = DatabaseManager() @@ -552,9 +556,11 @@ async def generate_examples(request: SynthesisRequest): if is_demo== True: if request.input_path: - return await synthesis_service.generate_result(request,is_demo, request_id=request_id) + # Custom_Workflow technique - route to legacy service + return await synthesis_legacy_service.generate_result(request,is_demo, request_id=request_id) else: - return await synthesis_service.generate_examples(request,is_demo, request_id=request_id) + # SFT technique - route to legacy service + return await synthesis_legacy_service.generate_examples(request,is_demo, request_id=request_id) else: return synthesis_job.generate_job(request, core, mem, request_id=request_id) @@ -626,7 +632,8 @@ async def evaluate_examples(request: EvaluationRequest): is_demo = request.is_demo if is_demo: - return evaluator_service.evaluate_results(request, request_id=request_id) + # SFT and Custom_Workflow evaluation - route to legacy service + return evaluator_legacy_service.evaluate_results(request, request_id=request_id) else: return synthesis_job.evaluate_job(request, request_id=request_id) @@ -1242,7 +1249,7 @@ def is_empty(self): async def health_check(): """Get API health status""" #return {"status": "healthy"} - return synthesis_service.get_health_check() + return synthesis_legacy_service.get_health_check() @app.get("/{use_case}/example_payloads") async def get_example_payloads(use_case:UseCase): @@ -1255,6 +1262,7 @@ async def get_example_payloads(use_case:UseCase): "technique": "sft", "topics": ["python_basics", "data_structures"], "is_demo": True, + "max_concurrent_topics": 5, "examples": [ { "question": "How do you create a list in Python and add elements to it?", @@ -1281,6 +1289,7 @@ async def get_example_payloads(use_case:UseCase): "technique": "sft", "topics": ["basic_queries", "joins"], "is_demo": True, + "max_concurrent_topics": 5, "schema": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100), email VARCHAR(255));\nCREATE TABLE orders (id INT PRIMARY KEY, user_id INT, amount DECIMAL(10,2), FOREIGN KEY (user_id) REFERENCES users(id));", "examples":[ { @@ -1309,6 +1318,7 @@ async def get_example_payloads(use_case:UseCase): "topics": ["topic 1", "topic 2"], "custom_prompt": "Give your instructions here", "is_demo": True, + "max_concurrent_topics": 5, "examples":[ { diff --git a/app/models/request_models.py b/app/models/request_models.py index eaf9d7e0..bf22343f 100644 --- a/app/models/request_models.py +++ b/app/models/request_models.py @@ -123,6 +123,7 @@ class SynthesisRequest(BaseModel): # Optional fields that can override defaults inference_type: Optional[str] = "aws_bedrock" caii_endpoint: Optional[str] = None + openai_compatible_endpoint: Optional[str] = None topics: Optional[List[str]] = None doc_paths: Optional[List[str]] = None input_path: 
Optional[List[str]] = None @@ -137,7 +138,13 @@ class SynthesisRequest(BaseModel): example_path: Optional[str] = None schema: Optional[str] = None # Added schema field custom_prompt: Optional[str] = None - display_name: Optional[str] = None + display_name: Optional[str] = None + max_concurrent_topics: Optional[int] = Field( + default=5, + ge=1, + le=100, + description="Maximum number of concurrent topics to process (1-100)" + ) # Optional model parameters with defaults model_params: Optional[ModelParameters] = Field( @@ -155,7 +162,7 @@ class SynthesisRequest(BaseModel): "technique": "sft", "topics": ["python_basics", "data_structures"], "is_demo": True, - + "max_concurrent_topics": 5 } } @@ -208,6 +215,12 @@ class EvaluationRequest(BaseModel): display_name: Optional[str] = None output_key: Optional[str] = 'Prompt' output_value: Optional[str] = 'Completion' + max_workers: Optional[int] = Field( + default=4, + ge=1, + le=100, + description="Maximum number of worker threads for parallel evaluation (1-100)" + ) # Export configuration export_type: str = "local" # "local" or "s3" @@ -226,7 +239,8 @@ class EvaluationRequest(BaseModel): "inference_type": "aws_bedrock", "import_path": "qa_pairs_llama3-1-70b-instruct-v1:0_20241114_212837_test.json", "import_type": "local", - "export_type":"local" + "export_type":"local", + "max_workers": 4 } } diff --git a/app/run_eval_job.py b/app/run_eval_job.py index 9591cee4..7a60714d 100644 --- a/app/run_eval_job.py +++ b/app/run_eval_job.py @@ -31,6 +31,7 @@ from app.models.request_models import EvaluationRequest, ModelParameters from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService import asyncio import nest_asyncio @@ -40,7 +41,7 @@ async def run_eval(request, job_name, request_id): try: - job = EvaluatorService() + job = EvaluatorLegacyService() result = job.evaluate_results(request,job_name, is_demo=False, request_id=request_id) return result except Exception as e: diff --git a/app/run_job.py b/app/run_job.py index 6d15833a..213c6d0a 100644 --- a/app/run_job.py +++ b/app/run_job.py @@ -32,6 +32,7 @@ import json from app.models.request_models import SynthesisRequest from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService import asyncio import nest_asyncio # Add this import @@ -41,7 +42,7 @@ async def run_synthesis(request, job_name, request_id): """Run standard synthesis job for question-answer pairs""" try: - job = SynthesisService() + job = SynthesisLegacyService() if request.input_path: result = await job.generate_result(request, job_name, is_demo=False, request_id=request_id) else: diff --git a/app/services/evaluator_legacy_service.py b/app/services/evaluator_legacy_service.py new file mode 100644 index 00000000..328c1f83 --- /dev/null +++ b/app/services/evaluator_legacy_service.py @@ -0,0 +1,409 @@ +import boto3 +from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from app.models.request_models import Example, ModelParameters, EvaluationRequest +from app.core.model_handlers import create_handler +from app.core.prompt_templates import PromptBuilder, PromptHandler +from app.services.aws_bedrock import get_bedrock_client +from app.core.database import DatabaseManager +from app.core.config import UseCase, Technique, get_model_family +from app.services.check_guardrail import ContentGuardrail +from 
app.core.exceptions import APIError, InvalidModelError, ModelHandlerError +import os +from datetime import datetime, timezone +import json +import logging +from logging.handlers import RotatingFileHandler +from app.core.telemetry_integration import track_llm_operation +from functools import partial + +class EvaluatorLegacyService: + """Legacy service for evaluating generated QA pairs using Claude with parallel processing (SFT and Custom_Workflow only)""" + + def __init__(self, max_workers: int = 4): + self.bedrock_client = get_bedrock_client() + self.db = DatabaseManager() + self.max_workers = max_workers # Default max workers (configurable via request) + self.guard = ContentGuardrail() + self._setup_logging() + + def _setup_logging(self): + """Set up logging configuration""" + os.makedirs('logs', exist_ok=True) + + self.logger = logging.getLogger('evaluator_legacy_service') + self.logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + # File handler for general logs + file_handler = RotatingFileHandler( + 'logs/evaluator_legacy_service.log', + maxBytes=10*1024*1024, # 10MB + backupCount=5 + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + # File handler for errors + error_handler = RotatingFileHandler( + 'logs/evaluator_legacy_service_errors.log', + maxBytes=10*1024*1024, + backupCount=5 + ) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(formatter) + self.logger.addHandler(error_handler) + + + #@track_llm_operation("evaluate_single_pair") + def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict: + """Evaluate a single QA pair""" + try: + # Default error response + error_response = { + request.output_key: qa_pair.get(request.output_key, "Unknown"), + request.output_value: qa_pair.get(request.output_value, "Unknown"), + "evaluation": { + "score": 0, + "justification": "Error during evaluation" + } + } + + try: + self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...") + except Exception as e: + self.logger.error(f"Error logging QA pair: {str(e)}") + + try: + # Validate input qa_pair structure + if not all(key in qa_pair for key in [request.output_key, request.output_value]): + error_msg = "Missing required keys in qa_pair" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + prompt = PromptBuilder.build_eval_prompt( + request.model_id, + request.use_case, + qa_pair[request.output_key], + qa_pair[request.output_value], + request.examples, + request.custom_prompt + ) + #print(prompt) + except Exception as e: + error_msg = f"Error building evaluation prompt: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + try: + response = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") + raise + except Exception as e: + error_msg = f"Error generating model response: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + if not response: + error_msg = "Failed to parse model response" + self.logger.warning(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + try: + score = response[0].get('score', "no score key") + justification = 
response[0].get('justification', 'No justification provided') + if score== "no score key": + self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}") + justification = "The evaluated pair did not generate valid score and justification" + score = 0 + else: + self.logger.info(f"Successfully evaluated QA pair with score: {score}") + + return { + "question": qa_pair[request.output_key], + "solution": qa_pair[request.output_value], + "evaluation": { + "score": score, + "justification": justification + } + } + except Exception as e: + error_msg = f"Error processing model response: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + except ModelHandlerError: + raise + except Exception as e: + self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}") + return error_response + + #@track_llm_operation("evaluate_topic") + def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict: + """Evaluate all QA pairs for a given topic in parallel""" + try: + self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs") + evaluated_pairs = [] + failed_pairs = [] + + try: + max_workers = request.max_workers or self.max_workers + with ThreadPoolExecutor(max_workers=max_workers) as executor: + try: + evaluate_func = partial( + self.evaluate_single_pair, + model_handler=model_handler, + request=request, request_id=request_id + ) + + future_to_pair = { + executor.submit(evaluate_func, pair): pair + for pair in qa_pairs + } + + for future in as_completed(future_to_pair): + try: + result = future.result() + evaluated_pairs.append(result) + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Error processing future result: {str(e)}" + self.logger.error(error_msg) + failed_pairs.append({ + "error": error_msg, + "pair": future_to_pair[future] + }) + + except Exception as e: + error_msg = f"Error in parallel execution: {str(e)}" + self.logger.error(error_msg) + raise + + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}" + self.logger.error(error_msg) + raise + + try: + # Calculate statistics only from successful evaluations + scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None] + + if scores: + average_score = sum(scores) / len(scores) + average_score = round(average_score, 2) + min_score = min(scores) + max_score = max(scores) + else: + average_score = min_score = max_score = 0 + + topic_stats = { + "average_score": average_score, + "min_score": min_score, + "max_score": max_score, + "evaluated_pairs": evaluated_pairs, + "failed_pairs": failed_pairs, + "total_evaluated": len(evaluated_pairs), + "total_failed": len(failed_pairs) + } + + self.logger.info(f"Completed evaluation for topic: {topic}. 
Average score: {topic_stats['average_score']:.2f}") + return topic_stats + + except Exception as e: + error_msg = f"Error calculating topic statistics: {str(e)}" + self.logger.error(error_msg) + return { + "average_score": 0, + "min_score": 0, + "max_score": 0, + "evaluated_pairs": evaluated_pairs, + "failed_pairs": failed_pairs, + "error": error_msg + } + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Critical error in evaluate_topic: {str(e)}" + self.logger.error(error_msg) + return { + "average_score": 0, + "min_score": 0, + "max_score": 0, + "evaluated_pairs": [], + "failed_pairs": [], + "error": error_msg + } + + #@track_llm_operation("evaluate_results") + def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict: + """Evaluate all QA pairs with parallel processing""" + try: + self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}") + + model_params = request.model_params or ModelParameters() + + self.logger.info(f"Creating model handler for model: {request.model_id}") + model_handler = create_handler( + request.model_id, + self.bedrock_client, + model_params=model_params, + inference_type = request.inference_type, + caii_endpoint = request.caii_endpoint + ) + + self.logger.info(f"Loading QA pairs from: {request.import_path}") + with open(request.import_path, 'r') as file: + data = json.load(file) + + evaluated_results = {} + all_scores = [] + + transformed_data = { + "results": {}, + } + for item in data: + topic = item.get('Seeds') + + # Create topic list if it doesn't exist + if topic not in transformed_data['results']: + transformed_data['results'][topic] = [] + + # Create QA pair + qa_pair = { + request.output_key: item.get(request.output_key, ''), # Use get() with default value + request.output_value: item.get(request.output_value, '') # Use get() with default value + } + + # Add to appropriate topic list + transformed_data['results'][topic].append(qa_pair) + + max_workers = request.max_workers or self.max_workers + self.logger.info(f"Processing {len(transformed_data['results'])} topics with {max_workers} workers") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_topic = { + executor.submit( + self.evaluate_topic, + topic, + qa_pairs, + model_handler, + request, request_id=request_id + ): topic + for topic, qa_pairs in transformed_data['results'].items() + } + + for future in as_completed(future_to_topic): + try: + topic = future_to_topic[future] + topic_stats = future.result() + evaluated_results[topic] = topic_stats + all_scores.extend([ + pair["evaluation"]["score"] + for pair in topic_stats["evaluated_pairs"] + ]) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in future processing: {str(e)}") + raise APIError(f"Model evaluation failed: {str(e)}") + + + overall_average = sum(all_scores) / len(all_scores) if all_scores else 0 + overall_average = round(overall_average, 2) + evaluated_results['Overall_Average'] = overall_average + + self.logger.info(f"Evaluation completed. 
Overall average score: {overall_average:.2f}") + + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + model_name = get_model_family(request.model_id).split('.')[-1] + output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json" + + self.logger.info(f"Saving evaluation results to: {output_path}") + with open(output_path, 'w') as f: + json.dump(evaluated_results, f, indent=2) + + custom_prompt_str = PromptHandler.get_default_custom_eval_prompt( + request.use_case, + request.custom_prompt + ) + + + examples_value = ( + PromptHandler.get_default_eval_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + #print(examples_value, '\n',examples_str) + + metadata = { + 'timestamp': timestamp, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'custom_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(request.import_path), + 'evaluate_file_name': os.path.basename(output_path), + 'display_name': request.display_name, + 'local_export_path': output_path, + 'examples': examples_str, + 'Overall_Average': overall_average + } + + self.logger.info("Saving evaluation metadata to database") + + + if is_demo: + self.db.save_evaluation_metadata(metadata) + return { + "status": "completed", + "result": evaluated_results, + "output_path": output_path + } + else: + + + job_status = "ENGINE_SUCCEEDED" + evaluate_file_name = os.path.basename(output_path) + self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status) + self.db.backup_and_restore_db() + return { + "status": "completed", + "output_path": output_path + } + except APIError: + raise + except ModelHandlerError as e: + # Add this specific handler + self.logger.error(f"ModelHandlerError in evaluation: {str(e)}") + raise APIError(str(e)) + except Exception as e: + error_msg = f"Error in evaluation process: {str(e)}" + self.logger.error(error_msg, exc_info=True) + if is_demo: + raise APIError(str(e)) + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "ENGINE_FAILED" + file_name = '' + output_path = '' + overall_average = '' + self.db.update_job_evaluate(job_name,file_name, output_path, time_stamp, job_status) + + raise + + def safe_json_dumps(self, value): + """Convert value to JSON string only if it's not None""" + return json.dumps(value) if value is not None else None diff --git a/app/services/evaluator_service.py b/app/services/evaluator_service.py index 2b094b15..f3fbbcad 100644 --- a/app/services/evaluator_service.py +++ b/app/services/evaluator_service.py @@ -19,12 +19,12 @@ from functools import partial class EvaluatorService: - """Service for evaluating generated QA pairs using Claude with parallel processing""" + """Service for evaluating freeform data rows using Claude with parallel processing (Freeform technique only)""" - def __init__(self, max_workers: int = 4): + def __init__(self, max_workers: int = 5): self.bedrock_client = get_bedrock_client() self.db = DatabaseManager() - self.max_workers = max_workers + self.max_workers = max_workers # Default max workers (configurable via request) self.guard = ContentGuardrail() self._setup_logging() @@ -56,351 +56,6 @@ def _setup_logging(self): 
error_handler.setFormatter(formatter) self.logger.addHandler(error_handler) - - #@track_llm_operation("evaluate_single_pair") - def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict: - """Evaluate a single QA pair""" - try: - # Default error response - error_response = { - request.output_key: qa_pair.get(request.output_key, "Unknown"), - request.output_value: qa_pair.get(request.output_value, "Unknown"), - "evaluation": { - "score": 0, - "justification": "Error during evaluation" - } - } - - try: - self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...") - except Exception as e: - self.logger.error(f"Error logging QA pair: {str(e)}") - - try: - # Validate input qa_pair structure - if not all(key in qa_pair for key in [request.output_key, request.output_value]): - error_msg = "Missing required keys in qa_pair" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - prompt = PromptBuilder.build_eval_prompt( - request.model_id, - request.use_case, - qa_pair[request.output_key], - qa_pair[request.output_value], - request.examples, - request.custom_prompt - ) - #print(prompt) - except Exception as e: - error_msg = f"Error building evaluation prompt: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - try: - response = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") - raise - except Exception as e: - error_msg = f"Error generating model response: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - if not response: - error_msg = "Failed to parse model response" - self.logger.warning(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - try: - score = response[0].get('score', "no score key") - justification = response[0].get('justification', 'No justification provided') - if score== "no score key": - self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}") - justification = "The evaluated pair did not generate valid score and justification" - score = 0 - else: - self.logger.info(f"Successfully evaluated QA pair with score: {score}") - - return { - "question": qa_pair[request.output_key], - "solution": qa_pair[request.output_value], - "evaluation": { - "score": score, - "justification": justification - } - } - except Exception as e: - error_msg = f"Error processing model response: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - except ModelHandlerError: - raise - except Exception as e: - self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}") - return error_response - - #@track_llm_operation("evaluate_topic") - def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict: - """Evaluate all QA pairs for a given topic in parallel""" - try: - self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs") - evaluated_pairs = [] - failed_pairs = [] - - try: - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - try: - evaluate_func = partial( - self.evaluate_single_pair, - model_handler=model_handler, - request=request, 
request_id=request_id - ) - - future_to_pair = { - executor.submit(evaluate_func, pair): pair - for pair in qa_pairs - } - - for future in as_completed(future_to_pair): - try: - result = future.result() - evaluated_pairs.append(result) - except ModelHandlerError: - raise - except Exception as e: - error_msg = f"Error processing future result: {str(e)}" - self.logger.error(error_msg) - failed_pairs.append({ - "error": error_msg, - "pair": future_to_pair[future] - }) - - except Exception as e: - error_msg = f"Error in parallel execution: {str(e)}" - self.logger.error(error_msg) - raise - - except ModelHandlerError: - raise - except Exception as e: - error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}" - self.logger.error(error_msg) - raise - - try: - # Calculate statistics only from successful evaluations - scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None] - - if scores: - average_score = sum(scores) / len(scores) - average_score = round(average_score, 2) - min_score = min(scores) - max_score = max(scores) - else: - average_score = min_score = max_score = 0 - - topic_stats = { - "average_score": average_score, - "min_score": min_score, - "max_score": max_score, - "evaluated_pairs": evaluated_pairs, - "failed_pairs": failed_pairs, - "total_evaluated": len(evaluated_pairs), - "total_failed": len(failed_pairs) - } - - self.logger.info(f"Completed evaluation for topic: {topic}. Average score: {topic_stats['average_score']:.2f}") - return topic_stats - - except Exception as e: - error_msg = f"Error calculating topic statistics: {str(e)}" - self.logger.error(error_msg) - return { - "average_score": 0, - "min_score": 0, - "max_score": 0, - "evaluated_pairs": evaluated_pairs, - "failed_pairs": failed_pairs, - "error": error_msg - } - except ModelHandlerError: - raise - except Exception as e: - error_msg = f"Critical error in evaluate_topic: {str(e)}" - self.logger.error(error_msg) - return { - "average_score": 0, - "min_score": 0, - "max_score": 0, - "evaluated_pairs": [], - "failed_pairs": [], - "error": error_msg - } - #@track_llm_operation("evaluate_results") - def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict: - """Evaluate all QA pairs with parallel processing""" - try: - self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}") - - model_params = request.model_params or ModelParameters() - - self.logger.info(f"Creating model handler for model: {request.model_id}") - model_handler = create_handler( - request.model_id, - self.bedrock_client, - model_params=model_params, - inference_type = request.inference_type, - caii_endpoint = request.caii_endpoint - ) - - self.logger.info(f"Loading QA pairs from: {request.import_path}") - with open(request.import_path, 'r') as file: - data = json.load(file) - - evaluated_results = {} - all_scores = [] - - transformed_data = { - "results": {}, - } - for item in data: - topic = item.get('Seeds') - - # Create topic list if it doesn't exist - if topic not in transformed_data['results']: - transformed_data['results'][topic] = [] - - # Create QA pair - qa_pair = { - request.output_key: item.get(request.output_key, ''), # Use get() with default value - request.output_value: item.get(request.output_value, '') # Use get() with default value - } - - # Add to appropriate topic list - transformed_data['results'][topic].append(qa_pair) - - self.logger.info(f"Processing {len(transformed_data['results'])} topics with 
{self.max_workers} workers") - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_topic = { - executor.submit( - self.evaluate_topic, - topic, - qa_pairs, - model_handler, - request, request_id=request_id - ): topic - for topic, qa_pairs in transformed_data['results'].items() - } - - for future in as_completed(future_to_topic): - try: - topic = future_to_topic[future] - topic_stats = future.result() - evaluated_results[topic] = topic_stats - all_scores.extend([ - pair["evaluation"]["score"] - for pair in topic_stats["evaluated_pairs"] - ]) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in future processing: {str(e)}") - raise APIError(f"Model evaluation failed: {str(e)}") - - - overall_average = sum(all_scores) / len(all_scores) if all_scores else 0 - overall_average = round(overall_average, 2) - evaluated_results['Overall_Average'] = overall_average - - self.logger.info(f"Evaluation completed. Overall average score: {overall_average:.2f}") - - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - model_name = get_model_family(request.model_id).split('.')[-1] - output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json" - - self.logger.info(f"Saving evaluation results to: {output_path}") - with open(output_path, 'w') as f: - json.dump(evaluated_results, f, indent=2) - - custom_prompt_str = PromptHandler.get_default_custom_eval_prompt( - request.use_case, - request.custom_prompt - ) - - - examples_value = ( - PromptHandler.get_default_eval_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - #print(examples_value, '\n',examples_str) - - metadata = { - 'timestamp': timestamp, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'custom_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(request.import_path), - 'evaluate_file_name': os.path.basename(output_path), - 'display_name': request.display_name, - 'local_export_path': output_path, - 'examples': examples_str, - 'Overall_Average': overall_average - } - - self.logger.info("Saving evaluation metadata to database") - - - if is_demo: - self.db.save_evaluation_metadata(metadata) - return { - "status": "completed", - "result": evaluated_results, - "output_path": output_path - } - else: - - - job_status = "ENGINE_SUCCEEDED" - evaluate_file_name = os.path.basename(output_path) - self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed", - "output_path": output_path - } - except APIError: - raise - except ModelHandlerError as e: - # Add this specific handler - self.logger.error(f"ModelHandlerError in evaluation: {str(e)}") - raise APIError(str(e)) - except Exception as e: - error_msg = f"Error in evaluation process: {str(e)}" - self.logger.error(error_msg, exc_info=True) - if is_demo: - raise APIError(str(e)) - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "ENGINE_FAILED" - file_name = '' - output_path = '' - overall_average = '' - self.db.update_job_evaluate(job_name,file_name, output_path, time_stamp, job_status) - - raise - def evaluate_single_row(self, row: Dict[str, 
Any], model_handler, request: EvaluationRequest, request_id = None) -> Dict: """Evaluate a single data row""" try: @@ -488,7 +143,8 @@ def evaluate_rows(self, rows: List[Dict[str, Any]], model_handler, request: Eval failed_rows = [] try: - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + max_workers = request.max_workers or self.max_workers + with ThreadPoolExecutor(max_workers=max_workers) as executor: try: evaluate_func = partial( self.evaluate_single_row, @@ -690,4 +346,4 @@ def evaluate_row_data(self, request: EvaluationRequest, job_name=None, is_demo: def safe_json_dumps(self, value): """Convert value to JSON string only if it's not None""" - return json.dumps(value) if value is not None else None \ No newline at end of file + return json.dumps(value) if value is not None else None diff --git a/app/services/model_alignment.py b/app/services/model_alignment.py index 650005a3..2b3fcfe1 100644 --- a/app/services/model_alignment.py +++ b/app/services/model_alignment.py @@ -7,8 +7,8 @@ import asyncio from datetime import datetime, timezone from typing import Dict, Optional -from app.services.synthesis_service import SynthesisService -from app.services.evaluator_service import EvaluatorService +from app.services.synthesis_legacy_service import SynthesisLegacyService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest from app.models.request_models import ModelParameters from app.services.aws_bedrock import get_bedrock_client @@ -21,8 +21,8 @@ class ModelAlignment: """Service for aligning model outputs through synthesis and evaluation""" def __init__(self): - self.synthesis_service = SynthesisService() - self.evaluator_service = EvaluatorService() + self.synthesis_service = SynthesisLegacyService() + self.evaluator_service = EvaluatorLegacyService() self.db = DatabaseManager() self.bedrock_client = get_bedrock_client() # Add this line self._setup_logging() diff --git a/app/services/synthesis_job.py b/app/services/synthesis_job.py index 323fa7f0..8ea42010 100644 --- a/app/services/synthesis_job.py +++ b/app/services/synthesis_job.py @@ -3,9 +3,9 @@ import uuid import os from typing import Dict, Any, Optional -from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath -from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService from app.services.export_results import Export_Service from app.core.prompt_templates import PromptBuilder, PromptHandler from app.core.config import UseCase, USE_CASE_CONFIGS @@ -22,8 +22,8 @@ import cmlapi # Initialize services -synthesis_service = SynthesisService() -evaluator_service = EvaluatorService() +synthesis_service = SynthesisLegacyService() +evaluator_service = EvaluatorLegacyService() export_service = Export_Service() db_manager = DatabaseManager() diff --git a/app/services/synthesis_legacy_service.py b/app/services/synthesis_legacy_service.py new file mode 100644 index 00000000..62df587c --- /dev/null +++ b/app/services/synthesis_legacy_service.py @@ -0,0 +1,701 @@ +import boto3 +import json +import uuid +import time +import csv +from typing import List, Dict, Optional, Tuple +import uuid +from datetime import datetime, timezone +import os 
+from huggingface_hub import HfApi, HfFolder, Repository +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import math +import asyncio +from fastapi import FastAPI, BackgroundTasks, HTTPException +from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError +from app.core.data_loader import DataLoader +import pandas as pd +import numpy as np + +from app.models.request_models import SynthesisRequest, Example, ModelParameters +from app.core.model_handlers import create_handler +from app.core.prompt_templates import PromptBuilder, PromptHandler +from app.core.config import UseCase, Technique, get_model_family +from app.services.aws_bedrock import get_bedrock_client +from app.core.database import DatabaseManager +from app.services.check_guardrail import ContentGuardrail +from app.services.doc_extraction import DocumentProcessor +import logging +from logging.handlers import RotatingFileHandler +import traceback +from app.core.telemetry_integration import track_llm_operation +import uuid + + +class SynthesisLegacyService: + """Legacy service for generating synthetic QA pairs (SFT and Custom_Workflow only)""" + QUESTIONS_PER_BATCH = 5 # Maximum questions per batch + MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request) + + + def __init__(self): + self.bedrock_client = get_bedrock_client() + self.db = DatabaseManager() + self._setup_logging() + self.guard = ContentGuardrail() + + + def _setup_logging(self): + """Set up logging configuration""" + os.makedirs('logs', exist_ok=True) + + self.logger = logging.getLogger('synthesis_legacy_service') + self.logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + # File handler for general logs + file_handler = RotatingFileHandler( + 'logs/synthesis_legacy_service.log', + maxBytes=10*1024*1024, # 10MB + backupCount=5 + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + # File handler for errors + error_handler = RotatingFileHandler( + 'logs/synthesis_legacy_service_errors.log', + maxBytes=10*1024*1024, + backupCount=5 + ) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(formatter) + self.logger.addHandler(error_handler) + + + #@track_llm_operation("process_single_topic") + def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: + """ + Process a single topic to generate questions and solutions. + Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails. 
+ + Args: + topic: The topic to generate questions for + model_handler: Handler for the AI model + request: The synthesis request object + num_questions: Total number of questions to generate + + Returns: + Tuple containing: + - topic (str) + - list of validated QA pairs + - list of error messages + - list of output dictionaries with topic information + + Raises: + ModelHandlerError: When there's an error in model generation that should stop processing + """ + topic_results = [] + topic_output = [] + topic_errors = [] + questions_remaining = num_questions + omit_questions = [] + + try: + # Process questions in batches + for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH): + if questions_remaining <= 0: + break + + batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining) + self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}") + + try: + # Attempt batch processing + prompt = PromptBuilder.build_prompt( + model_id=request.model_id, + use_case=request.use_case, + topic=topic, + num_questions=batch_size, + omit_questions=omit_questions, + examples=request.examples or [], + technique=request.technique, + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + # print("prompt :", prompt) + batch_qa_pairs = None + try: + batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.warning(f"Batch processing failed: {str(e)}") + if isinstance(e, JSONParsingError): + # For JSON parsing errors, fall back to single processing + self.logger.info("JSON parsing failed, falling back to single processing") + continue + else: + # For other model errors, propagate up + raise + + if batch_qa_pairs: + # Process batch results + valid_pairs = [] + valid_outputs = [] + invalid_count = 0 + + for pair in batch_qa_pairs: + if self._validate_qa_pair(pair): + valid_pairs.append({ + "question": pair["question"], + "solution": pair["solution"] + }) + valid_outputs.append({ + "Topic": topic, + "question": pair["question"], + "solution": pair["solution"] + }) + omit_questions.append(pair["question"]) + #else: + invalid_count = batch_size - len(valid_pairs) + + if valid_pairs: + topic_results.extend(valid_pairs) + topic_output.extend(valid_outputs) + questions_remaining -= len(valid_pairs) + omit_questions = omit_questions[-100:] # Keep last 100 questions + self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}") + print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) + # If all pairs were valid, skip fallback + if invalid_count <= 0: + continue + + else: + # Fall back to single processing for remaining or failed questions + self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}") + remaining_batch = invalid_count + print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) + for _ in range(remaining_batch): + if questions_remaining <= 0: + break + + try: + # Single question processing + prompt = PromptBuilder.build_prompt( + model_id=request.model_id, + use_case=request.use_case, + topic=topic, + num_questions=1, + omit_questions=omit_questions, + examples=request.examples or [], + technique=request.technique, + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + + try: + single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) + except 
ModelHandlerError as e: + self.logger.warning(f"Batch processing failed: {str(e)}") + if isinstance(e, JSONParsingError): + # For JSON parsing errors, fall back to single processing + self.logger.info("JSON parsing failed, falling back to single processing") + continue + else: + # For other model errors, propagate up + raise + + if single_qa_pairs and len(single_qa_pairs) > 0: + pair = single_qa_pairs[0] + if self._validate_qa_pair(pair): + validated_pair = { + "question": pair["question"], + "solution": pair["solution"] + } + validated_output = { + "Topic": topic, + "question": pair["question"], + "solution": pair["solution"] + } + + topic_results.append(validated_pair) + topic_output.append(validated_output) + omit_questions.append(pair["question"]) + omit_questions = omit_questions[-100:] + questions_remaining -= 1 + + self.logger.info(f"Successfully generated single question for topic {topic}") + else: + error_msg = f"Invalid QA pair structure in single processing for topic {topic}" + self.logger.warning(error_msg) + topic_errors.append(error_msg) + else: + error_msg = f"No QA pair generated in single processing for topic {topic}" + self.logger.warning(error_msg) + topic_errors.append(error_msg) + + except ModelHandlerError as e: + # Don't raise - add to errors and continue + error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + continue + + except ModelHandlerError: + # Re-raise ModelHandlerError to propagate up + raise + except Exception as e: + error_msg = f"Error processing batch for topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + continue + + except ModelHandlerError: + # Re-raise ModelHandlerError to propagate up + raise + except Exception as e: + error_msg = f"Critical error processing topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + + return topic, topic_results, topic_errors, topic_output + + + async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict: + """Generate examples based on request parameters (SFT technique)""" + try: + output_key = request.output_key + output_value = request.output_value + st = time.time() + self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") + + # Use default parameters if none provided + model_params = request.model_params or ModelParameters() + + # Create model handler + self.logger.info("Creating model handler") + model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint) + + # Limit topics and questions in demo mode + if request.doc_paths: + processor = DocumentProcessor(chunk_size=1000, overlap=100) + paths = request.doc_paths + topics = [] + for path in paths: + chunks = processor.process_document(path) + topics.extend(chunks) + #topics = topics[0:5] + print("total chunks: ", len(topics)) + if request.num_questions<=len(topics): + topics = topics[0:request.num_questions] + num_questions = 1 + print("num_questions :", num_questions) + else: + num_questions = math.ceil(request.num_questions/len(topics)) + #print(num_questions) + total_count = request.num_questions + else: + if request.topics: + topics = request.topics + num_questions = request.num_questions + total_count = request.num_questions*len(request.topics) + + else: + self.logger.error("Generation failed: 
No topics provided") + raise RuntimeError("Invalid input: No topics provided") + + + # Track results for each topic + results = {} + all_errors = [] + final_output = [] + + # Create thread pool + loop = asyncio.get_event_loop() + max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS + with ThreadPoolExecutor(max_workers=max_workers) as executor: + topic_futures = [ + loop.run_in_executor( + executor, + self.process_single_topic, + topic, + model_handler, + request, + num_questions, + request_id + ) + for topic in topics + ] + + # Wait for all futures to complete + try: + completed_topics = await asyncio.gather(*topic_futures) + except ModelHandlerError as e: + self.logger.error(f"Model generation failed: {str(e)}") + raise APIError(f"Failed to generate content: {str(e)}") + + # Process results + + for topic, topic_results, topic_errors, topic_output in completed_topics: + if topic_errors: + all_errors.extend(topic_errors) + if topic_results and is_demo: + results[topic] = topic_results + if topic_output: + final_output.extend(topic_output) + + generation_time = time.time() - st + self.logger.info(f"Generation completed in {generation_time:.2f} seconds") + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + mode_suffix = "test" if is_demo else "final" + model_name = get_model_family(request.model_id).split('.')[-1] + file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" + if request.doc_paths: + final_output = [{ + 'Generated_From': item['Topic'], + output_key: item['question'], + output_value: item['solution'] } + for item in final_output] + else: + final_output = [{ + 'Seeds': item['Topic'], + output_key: item['question'], + output_value: item['solution'] } + for item in final_output] + output_path = {} + try: + with open(file_path, "w") as f: + json.dump(final_output, indent=2, fp=f) + except Exception as e: + self.logger.error(f"Error saving results: {str(e)}", exc_info=True) + + output_path['local']= file_path + + + + + # Handle custom prompt, examples and schema + custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) + # For examples + examples_value = ( + PromptHandler.get_default_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + + # For schema + schema_value = ( + PromptHandler.get_default_schema(request.use_case, request.schema) + if hasattr(request, 'schema') + else None + ) + schema_str = self.safe_json_dumps(schema_value) + + # For topics + topics_value = topics if not getattr(request, 'doc_paths', None) else None + topic_str = self.safe_json_dumps(topics_value) + + # For doc_paths and input_path + doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) + input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) + + metadata = { + 'timestamp': timestamp, + 'technique': request.technique, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'final_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(output_path['local']), + 'display_name': request.display_name, + 'output_path': output_path, + 'num_questions':getattr(request, 'num_questions', None), + 'topics': topic_str, + 'examples': 
examples_str, + "total_count":total_count, + 'schema': schema_str, + 'doc_paths': doc_paths_str, + 'input_path':input_path_str, + 'input_key': request.input_key, + 'output_key':request.output_key, + 'output_value':request.output_value + } + + #print("metadata: ",metadata) + if is_demo: + + self.db.save_generation_metadata(metadata) + return { + "status": "completed" if results else "failed", + "results": results, + "errors": all_errors if all_errors else None, + "export_path": output_path + } + else: + # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) + # time_stamp = extract_timestamp(metadata.get('generate_file_name')) + job_status = "ENGINE_SUCCEEDED" + generate_file_name = os.path.basename(output_path['local']) + + self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) + self.db.backup_and_restore_db() + return { + "status": "completed" if final_output else "failed", + "export_path": output_path + } + except APIError: + raise + + except Exception as e: + self.logger.error(f"Generation failed: {str(e)}", exc_info=True) + if is_demo: + raise APIError(str(e)) # Let middleware decide status code + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "ENGINE_FAILED" + file_name = '' + output_path = '' + self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) + raise # Just re-raise the original exception + + + def _validate_qa_pair(self, pair: Dict) -> bool: + """Validate a question-answer pair""" + return ( + isinstance(pair, dict) and + "question" in pair and + "solution" in pair and + isinstance(pair["question"], str) and + isinstance(pair["solution"], str) and + len(pair["question"].strip()) > 0 and + len(pair["solution"].strip()) > 0 + ) + + #@track_llm_operation("process_single_input") + async def process_single_input(self, input, model_handler, request, request_id=None): + try: + prompt = PromptBuilder.build_generate_result_prompt( + model_id=request.model_id, + use_case=request.use_case, + input=input, + examples=request.examples or [], + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + try: + result = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") + raise + + return {"question": input, "solution": result} + + except ModelHandlerError: + raise + except Exception as e: + self.logger.error(f"Error processing input: {str(e)}") + raise APIError(f"Failed to process input: {str(e)}") + + async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict: + """Generate results based on request parameters (Custom_Workflow technique)""" + try: + self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") + + + # Use default parameters if none provided + model_params = request.model_params or ModelParameters() + + # Create model handler + self.logger.info("Creating model handler") + model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True) + + inputs = [] + file_paths = request.input_path + for path in file_paths: + try: + with open(path) as f: + data = json.load(f) + inputs.extend(item.get(request.input_key, '') for item in data) + except Exception as e: + print(f"Error processing {path}: {str(e)}") + MAX_WORKERS = 5 + + + # 
Create thread pool + loop = asyncio.get_event_loop() + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + # Create futures for each input + input_futures = [ + loop.run_in_executor( + executor, + lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)), + input + ) + for input in inputs + ] + + # Wait for all futures to complete + try: + final_output = await asyncio.gather(*input_futures) + except ModelHandlerError as e: + self.logger.error(f"Model generation failed: {str(e)}") + raise APIError(f"Failed to generate content: {str(e)}") + + + + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + mode_suffix = "test" if is_demo else "final" + model_name = get_model_family(request.model_id).split('.')[-1] + file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" + input_key = request.output_key or request.input_key + result = [{ + + input_key: item['question'], + request.output_value: item['solution'] } + for item in final_output] + output_path = {} + try: + with open(file_path, "w") as f: + json.dump(result, indent=2, fp=f) + except Exception as e: + self.logger.error(f"Error saving results: {str(e)}", exc_info=True) + + + + + output_path['local']= file_path + + + # Handle custom prompt, examples and schema + custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) + # For examples + examples_value = ( + PromptHandler.get_default_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + + # For schema + schema_value = ( + PromptHandler.get_default_schema(request.use_case, request.schema) + if hasattr(request, 'schema') + else None + ) + schema_str = self.safe_json_dumps(schema_value) + + # For topics + topics_value = None + topic_str = self.safe_json_dumps(topics_value) + + # For doc_paths and input_path + doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) + input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) + + + + metadata = { + 'timestamp': timestamp, + 'technique': request.technique, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'final_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(output_path['local']), + 'display_name': request.display_name, + 'output_path': output_path, + 'num_questions':getattr(request, 'num_questions', None), + 'topics': topic_str, + 'examples': examples_str, + "total_count":len(inputs), + 'schema': schema_str, + 'doc_paths': doc_paths_str, + 'input_path':input_path_str, + 'input_key': request.input_key, + 'output_key':request.output_key, + 'output_value':request.output_value + } + + + if is_demo: + + self.db.save_generation_metadata(metadata) + return { + "status": "completed" if final_output else "failed", + "results": final_output, + "export_path": output_path + } + else: + # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) + # time_stamp = extract_timestamp(metadata.get('generate_file_name')) + job_status = "success" + generate_file_name = os.path.basename(output_path['local']) + + self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) + self.db.backup_and_restore_db() + 
return { + "status": "completed" if final_output else "failed", + "export_path": output_path + } + + except APIError: + raise + except Exception as e: + self.logger.error(f"Generation failed: {str(e)}", exc_info=True) + if is_demo: + raise APIError(str(e)) # Let middleware decide status code + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "failure" + file_name = '' + output_path = '' + self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) + raise # Just re-raise the original exception + + def get_health_check(self) -> Dict: + """Get service health status""" + try: + test_body = { + "prompt": "\n\nHuman: test\n\nAssistant: ", + "max_tokens_to_sample": 1, + "temperature": 0 + } + + self.bedrock_client.invoke_model( + modelId="anthropic.claude-instant-v1", + body=json.dumps(test_body) + ) + + status = { + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "service": "SynthesisLegacyService", + "aws_region": self.bedrock_client.meta.region_name + } + self.logger.info("Health check passed", extra=status) + return status + + except Exception as e: + status = { + "status": "unhealthy", + "error": str(e), + "timestamp": datetime.now().isoformat(), + "service": "SynthesisLegacyService", + "aws_region": self.bedrock_client.meta.region_name + } + self.logger.error("Health check failed", extra=status, exc_info=True) + return status + + def safe_json_dumps(self, value): + """Convert value to JSON string only if it's not None""" + return json.dumps(value) if value is not None else None diff --git a/app/services/synthesis_service.py b/app/services/synthesis_service.py index e42ac3c8..6f8e643e 100644 --- a/app/services/synthesis_service.py +++ b/app/services/synthesis_service.py @@ -34,19 +34,13 @@ class SynthesisService: - """Service for generating synthetic QA pairs""" + """Service for generating synthetic freeform data (Freeform technique only)""" QUESTIONS_PER_BATCH = 5 # Maximum questions per batch - MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations + MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request) def __init__(self): - # self.bedrock_client = boto3.Session(profile_name='cu_manowar_dev').client( - # 'bedrock-runtime', - # region_name='us-west-2' - # ) - self.bedrock_client = get_bedrock_client() - - + self.bedrock_client = get_bedrock_client() self.db = DatabaseManager() self._setup_logging() self.guard = ContentGuardrail() @@ -80,592 +74,6 @@ def _setup_logging(self): error_handler.setFormatter(formatter) self.logger.addHandler(error_handler) - - #@track_llm_operation("process_single_topic") - def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: - """ - Process a single topic to generate questions and solutions. - Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails. 
- - Args: - topic: The topic to generate questions for - model_handler: Handler for the AI model - request: The synthesis request object - num_questions: Total number of questions to generate - - Returns: - Tuple containing: - - topic (str) - - list of validated QA pairs - - list of error messages - - list of output dictionaries with topic information - - Raises: - ModelHandlerError: When there's an error in model generation that should stop processing - """ - topic_results = [] - topic_output = [] - topic_errors = [] - questions_remaining = num_questions - omit_questions = [] - - try: - # Process questions in batches - for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH): - if questions_remaining <= 0: - break - - batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining) - self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}") - - try: - # Attempt batch processing - prompt = PromptBuilder.build_prompt( - model_id=request.model_id, - use_case=request.use_case, - topic=topic, - num_questions=batch_size, - omit_questions=omit_questions, - examples=request.examples or [], - technique=request.technique, - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - # print("prompt :", prompt) - batch_qa_pairs = None - try: - batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.warning(f"Batch processing failed: {str(e)}") - if isinstance(e, JSONParsingError): - # For JSON parsing errors, fall back to single processing - self.logger.info("JSON parsing failed, falling back to single processing") - continue - else: - # For other model errors, propagate up - raise - - if batch_qa_pairs: - # Process batch results - valid_pairs = [] - valid_outputs = [] - invalid_count = 0 - - for pair in batch_qa_pairs: - if self._validate_qa_pair(pair): - valid_pairs.append({ - "question": pair["question"], - "solution": pair["solution"] - }) - valid_outputs.append({ - "Topic": topic, - "question": pair["question"], - "solution": pair["solution"] - }) - omit_questions.append(pair["question"]) - #else: - invalid_count = batch_size - len(valid_pairs) - - if valid_pairs: - topic_results.extend(valid_pairs) - topic_output.extend(valid_outputs) - questions_remaining -= len(valid_pairs) - omit_questions = omit_questions[-100:] # Keep last 100 questions - self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}") - print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) - # If all pairs were valid, skip fallback - if invalid_count <= 0: - continue - - else: - # Fall back to single processing for remaining or failed questions - self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}") - remaining_batch = invalid_count - print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) - for _ in range(remaining_batch): - if questions_remaining <= 0: - break - - try: - # Single question processing - prompt = PromptBuilder.build_prompt( - model_id=request.model_id, - use_case=request.use_case, - topic=topic, - num_questions=1, - omit_questions=omit_questions, - examples=request.examples or [], - technique=request.technique, - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - - try: - single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) - except 
ModelHandlerError as e: - self.logger.warning(f"Batch processing failed: {str(e)}") - if isinstance(e, JSONParsingError): - # For JSON parsing errors, fall back to single processing - self.logger.info("JSON parsing failed, falling back to single processing") - continue - else: - # For other model errors, propagate up - raise - - if single_qa_pairs and len(single_qa_pairs) > 0: - pair = single_qa_pairs[0] - if self._validate_qa_pair(pair): - validated_pair = { - "question": pair["question"], - "solution": pair["solution"] - } - validated_output = { - "Topic": topic, - "question": pair["question"], - "solution": pair["solution"] - } - - topic_results.append(validated_pair) - topic_output.append(validated_output) - omit_questions.append(pair["question"]) - omit_questions = omit_questions[-100:] - questions_remaining -= 1 - - self.logger.info(f"Successfully generated single question for topic {topic}") - else: - error_msg = f"Invalid QA pair structure in single processing for topic {topic}" - self.logger.warning(error_msg) - topic_errors.append(error_msg) - else: - error_msg = f"No QA pair generated in single processing for topic {topic}" - self.logger.warning(error_msg) - topic_errors.append(error_msg) - - except ModelHandlerError as e: - # Don't raise - add to errors and continue - error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - continue - - except ModelHandlerError: - # Re-raise ModelHandlerError to propagate up - raise - except Exception as e: - error_msg = f"Error processing batch for topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - continue - - except ModelHandlerError: - # Re-raise ModelHandlerError to propagate up - raise - except Exception as e: - error_msg = f"Critical error processing topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - - return topic, topic_results, topic_errors, topic_output - - - async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict: - """Generate examples based on request parameters""" - try: - output_key = request.output_key - output_value = request.output_value - st = time.time() - self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") - - # Use default parameters if none provided - model_params = request.model_params or ModelParameters() - - # Create model handler - self.logger.info("Creating model handler") - model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint) - - # Limit topics and questions in demo mode - if request.doc_paths: - processor = DocumentProcessor(chunk_size=1000, overlap=100) - paths = request.doc_paths - topics = [] - for path in paths: - chunks = processor.process_document(path) - topics.extend(chunks) - #topics = topics[0:5] - print("total chunks: ", len(topics)) - if request.num_questions<=len(topics): - topics = topics[0:request.num_questions] - num_questions = 1 - print("num_questions :", num_questions) - else: - num_questions = math.ceil(request.num_questions/len(topics)) - #print(num_questions) - total_count = request.num_questions - else: - if request.topics: - topics = request.topics - num_questions = request.num_questions - total_count = request.num_questions*len(request.topics) - - else: - self.logger.error("Generation failed: No topics 
provided") - raise RuntimeError("Invalid input: No topics provided") - - - # Track results for each topic - results = {} - all_errors = [] - final_output = [] - - # Create thread pool - loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: - topic_futures = [ - loop.run_in_executor( - executor, - self.process_single_topic, - topic, - model_handler, - request, - num_questions, - request_id - ) - for topic in topics - ] - - # Wait for all futures to complete - try: - completed_topics = await asyncio.gather(*topic_futures) - except ModelHandlerError as e: - self.logger.error(f"Model generation failed: {str(e)}") - raise APIError(f"Failed to generate content: {str(e)}") - - # Process results - - for topic, topic_results, topic_errors, topic_output in completed_topics: - if topic_errors: - all_errors.extend(topic_errors) - if topic_results and is_demo: - results[topic] = topic_results - if topic_output: - final_output.extend(topic_output) - - generation_time = time.time() - st - self.logger.info(f"Generation completed in {generation_time:.2f} seconds") - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - mode_suffix = "test" if is_demo else "final" - model_name = get_model_family(request.model_id).split('.')[-1] - file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" - if request.doc_paths: - final_output = [{ - 'Generated_From': item['Topic'], - output_key: item['question'], - output_value: item['solution'] } - for item in final_output] - else: - final_output = [{ - 'Seeds': item['Topic'], - output_key: item['question'], - output_value: item['solution'] } - for item in final_output] - output_path = {} - try: - with open(file_path, "w") as f: - json.dump(final_output, indent=2, fp=f) - except Exception as e: - self.logger.error(f"Error saving results: {str(e)}", exc_info=True) - - output_path['local']= file_path - - - - - # Handle custom prompt, examples and schema - custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) - # For examples - examples_value = ( - PromptHandler.get_default_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - - # For schema - schema_value = ( - PromptHandler.get_default_schema(request.use_case, request.schema) - if hasattr(request, 'schema') - else None - ) - schema_str = self.safe_json_dumps(schema_value) - - # For topics - topics_value = topics if not getattr(request, 'doc_paths', None) else None - topic_str = self.safe_json_dumps(topics_value) - - # For doc_paths and input_path - doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) - input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) - - metadata = { - 'timestamp': timestamp, - 'technique': request.technique, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'final_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(output_path['local']), - 'display_name': request.display_name, - 'output_path': output_path, - 'num_questions':getattr(request, 'num_questions', None), - 'topics': topic_str, - 'examples': examples_str, - "total_count":total_count, - 'schema': schema_str, - 
'doc_paths': doc_paths_str, - 'input_path':input_path_str, - 'input_key': request.input_key, - 'output_key':request.output_key, - 'output_value':request.output_value - } - - #print("metadata: ",metadata) - if is_demo: - - self.db.save_generation_metadata(metadata) - return { - "status": "completed" if results else "failed", - "results": results, - "errors": all_errors if all_errors else None, - "export_path": output_path - } - else: - # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) - # time_stamp = extract_timestamp(metadata.get('generate_file_name')) - job_status = "ENGINE_SUCCEEDED" - generate_file_name = os.path.basename(output_path['local']) - - self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed" if final_output else "failed", - "export_path": output_path - } - except APIError: - raise - - except Exception as e: - self.logger.error(f"Generation failed: {str(e)}", exc_info=True) - if is_demo: - raise APIError(str(e)) # Let middleware decide status code - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "ENGINE_FAILED" - file_name = '' - output_path = '' - self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) - raise # Just re-raise the original exception - - - def _validate_qa_pair(self, pair: Dict) -> bool: - """Validate a question-answer pair""" - return ( - isinstance(pair, dict) and - "question" in pair and - "solution" in pair and - isinstance(pair["question"], str) and - isinstance(pair["solution"], str) and - len(pair["question"].strip()) > 0 and - len(pair["solution"].strip()) > 0 - ) - #@track_llm_operation("process_single_input") - async def process_single_input(self, input, model_handler, request, request_id=None): - try: - prompt = PromptBuilder.build_generate_result_prompt( - model_id=request.model_id, - use_case=request.use_case, - input=input, - examples=request.examples or [], - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - try: - result = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") - raise - - return {"question": input, "solution": result} - - except ModelHandlerError: - raise - except Exception as e: - self.logger.error(f"Error processing input: {str(e)}") - raise APIError(f"Failed to process input: {str(e)}") - - async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict: - try: - self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") - - - # Use default parameters if none provided - model_params = request.model_params or ModelParameters() - - # Create model handler - self.logger.info("Creating model handler") - model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True) - - inputs = [] - file_paths = request.input_path - for path in file_paths: - try: - with open(path) as f: - data = json.load(f) - inputs.extend(item.get(request.input_key, '') for item in data) - except Exception as e: - print(f"Error processing {path}: {str(e)}") - MAX_WORKERS = 5 - - - # Create thread pool - loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - # Create futures for each input - 
input_futures = [ - loop.run_in_executor( - executor, - lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)), - input - ) - for input in inputs - ] - - # Wait for all futures to complete - try: - final_output = await asyncio.gather(*input_futures) - except ModelHandlerError as e: - self.logger.error(f"Model generation failed: {str(e)}") - raise APIError(f"Failed to generate content: {str(e)}") - - - - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - mode_suffix = "test" if is_demo else "final" - model_name = get_model_family(request.model_id).split('.')[-1] - file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" - input_key = request.output_key or request.input_key - result = [{ - - input_key: item['question'], - request.output_value: item['solution'] } - for item in final_output] - output_path = {} - try: - with open(file_path, "w") as f: - json.dump(result, indent=2, fp=f) - except Exception as e: - self.logger.error(f"Error saving results: {str(e)}", exc_info=True) - - - - - - output_path['local']= file_path - - - # Handle custom prompt, examples and schema - custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) - # For examples - examples_value = ( - PromptHandler.get_default_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - - # For schema - schema_value = ( - PromptHandler.get_default_schema(request.use_case, request.schema) - if hasattr(request, 'schema') - else None - ) - schema_str = self.safe_json_dumps(schema_value) - - # For topics - topics_value = None - topic_str = self.safe_json_dumps(topics_value) - - # For doc_paths and input_path - doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) - input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) - - - - metadata = { - 'timestamp': timestamp, - 'technique': request.technique, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'final_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(output_path['local']), - 'display_name': request.display_name, - 'output_path': output_path, - 'num_questions':getattr(request, 'num_questions', None), - 'topics': topic_str, - 'examples': examples_str, - "total_count":len(inputs), - 'schema': schema_str, - 'doc_paths': doc_paths_str, - 'input_path':input_path_str, - 'input_key': request.input_key, - 'output_key':request.output_key, - 'output_value':request.output_value - } - - - if is_demo: - - self.db.save_generation_metadata(metadata) - return { - "status": "completed" if final_output else "failed", - "results": final_output, - "export_path": output_path - } - else: - # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) - # time_stamp = extract_timestamp(metadata.get('generate_file_name')) - job_status = "success" - generate_file_name = os.path.basename(output_path['local']) - - self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed" if final_output else "failed", - "export_path": output_path - } - - except APIError: - raise - except Exception as e: 
- self.logger.error(f"Generation failed: {str(e)}", exc_info=True) - if is_demo: - raise APIError(str(e)) # Let middleware decide status code - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "failure" - file_name = '' - output_path = '' - self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) - raise # Just re-raise the original exception - #@track_llm_operation("process_single_freeform") def process_single_freeform(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: """ @@ -960,7 +368,8 @@ async def generate_freeform(self, request: SynthesisRequest, job_name=None, is_d # Create thread pool loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: + max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS + with ThreadPoolExecutor(max_workers=max_workers) as executor: topic_futures = [ loop.run_in_executor( executor, @@ -1209,4 +618,4 @@ def get_health_check(self) -> Dict: def safe_json_dumps(self, value): """Convert value to JSON string only if it's not None""" - return json.dumps(value) if value is not None else None \ No newline at end of file + return json.dumps(value) if value is not None else None diff --git a/tests/integration/test_evaluate_api.py b/tests/integration/test_evaluate_api.py index a842cd55..2ff664a9 100644 --- a/tests/integration/test_evaluate_api.py +++ b/tests/integration/test_evaluate_api.py @@ -3,7 +3,7 @@ import json from fastapi.testclient import TestClient from pathlib import Path -from app.main import app, db_manager, evaluator_service # global instance created on import +from app.main import app, db_manager, evaluator_legacy_service # global instance created on import client = TestClient(app) # Create a dummy bedrock client that simulates the Converse/invoke_model response. @@ -37,8 +37,8 @@ def mock_qa_file(tmp_path, mock_qa_data): # Patch the global evaluator_service's AWS client before tests run. @pytest.fixture(autouse=True) def patch_evaluator_bedrock_client(): - from app.main import evaluator_service - evaluator_service.bedrock_client = DummyBedrockClient() + from app.main import evaluator_legacy_service + evaluator_legacy_service.bedrock_client = DummyBedrockClient() def test_evaluate_endpoint(mock_qa_file): request_data = { @@ -51,7 +51,7 @@ def test_evaluate_endpoint(mock_qa_file): "output_value": "Completion" } # Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation. - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] response = client.post("/synthesis/evaluate", json=request_data) # In demo mode, our endpoint returns a dict with "status", "result", and "output_path". 
@@ -71,7 +71,7 @@ def test_job_handling(mock_qa_file): "output_key": "Prompt", "output_value": "Completion" } - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] response = client.post("/synthesis/evaluate", json=request_data) assert response.status_code == 200 diff --git a/tests/integration/test_evaluate_legacy_api.py b/tests/integration/test_evaluate_legacy_api.py new file mode 100644 index 00000000..2ff664a9 --- /dev/null +++ b/tests/integration/test_evaluate_legacy_api.py @@ -0,0 +1,113 @@ +import pytest +from unittest.mock import patch, Mock +import json +from fastapi.testclient import TestClient +from pathlib import Path +from app.main import app, db_manager, evaluator_legacy_service # global instance created on import +client = TestClient(app) + +# Create a dummy bedrock client that simulates the Converse/invoke_model response. +class DummyBedrockClient: + def invoke_model(self, modelId, body): + # Return a dummy response structure (adjust if your handler expects a different format) + return [{ + "score": 1.0, + "justification": "Dummy response from invoke_model" + }] + @property + def meta(self): + class Meta: + region_name = "us-west-2" + return Meta() + +@pytest.fixture +def mock_qa_data(): + return [ + {"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"}, + {"Seeds": "python_basics", "Prompt": "How do you define a function?", "Completion": "Use the def keyword followed by function name"} + ] + +@pytest.fixture +def mock_qa_file(tmp_path, mock_qa_data): + file_path = tmp_path / "qa_pairs.json" + with open(file_path, "w") as f: + json.dump(mock_qa_data, f) + return str(file_path) + +# Patch the global evaluator_service's AWS client before tests run. +@pytest.fixture(autouse=True) +def patch_evaluator_bedrock_client(): + from app.main import evaluator_legacy_service + evaluator_legacy_service.bedrock_client = DummyBedrockClient() + +def test_evaluate_endpoint(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + # Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation. + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] + response = client.post("/synthesis/evaluate", json=request_data) + # In demo mode, our endpoint returns a dict with "status", "result", and "output_path". 
+ assert response.status_code == 200 + res_json = response.json() + assert res_json["status"] == "completed" + assert "output_path" in res_json + assert "result" in res_json + +def test_job_handling(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] + response = client.post("/synthesis/evaluate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + # In demo mode, we don't expect a "job_id" key; we check for "output_path" and "result". + assert "output_path" in res_json + # Now simulate history by patching db_manager.get_all_evaluate_metadata + db_manager.get_all_evaluate_metadata = lambda: [{"evaluate_file_name": "test.json", "average_score": 0.9}] + response = client.get("/evaluations/history") + assert response.status_code == 200 + history = response.json() + assert len(history) > 0 + +def test_evaluate_with_invalid_model(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "invalid.model", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + + from app.core.exceptions import ModelHandlerError + + # Patch create_handler to raise ModelHandlerError + with patch('app.services.evaluator_service.create_handler') as mock_create: + mock_create.side_effect = ModelHandlerError("Invalid model identifier: invalid.model") + response = client.post("/synthesis/evaluate", json=request_data) + + # Print for debugging + print(f"Response status: {response.status_code}") + print(f"Response content: {response.json()}") + + # Expect a 400 or 500 error response + assert response.status_code in [400, 500] + res_json = response.json() + assert "error" in res_json diff --git a/tests/integration/test_synthesis_api.py b/tests/integration/test_synthesis_api.py index cbd60e9f..db54d9c7 100644 --- a/tests/integration/test_synthesis_api.py +++ b/tests/integration/test_synthesis_api.py @@ -14,7 +14,7 @@ def test_generate_endpoint_with_topics(): "topics": ["python_basics"], "is_demo": True } - with patch('app.main.SynthesisService.generate_examples') as mock_generate: + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: mock_generate.return_value = { "status": "completed", "export_path": {"local": "test.json"}, @@ -35,7 +35,7 @@ def test_generate_endpoint_with_doc_paths(): "doc_paths": ["test.pdf"], "is_demo": True } - with patch('app.main.SynthesisService.generate_examples') as mock_generate: + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: mock_generate.return_value = { "status": "completed", "export_path": {"local": "test.json"}, diff --git a/tests/integration/test_synthesis_legacy_api.py b/tests/integration/test_synthesis_legacy_api.py new file mode 100644 index 00000000..db54d9c7 --- /dev/null +++ b/tests/integration/test_synthesis_legacy_api.py @@ -0,0 +1,91 @@ +import pytest +from unittest.mock import patch +import json +from fastapi.testclient import TestClient +from app.main import app, db_manager +client = TestClient(app) + +def test_generate_endpoint_with_topics(): + request_data = { + 
"use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "num_questions": 2, + "topics": ["python_basics"], + "is_demo": True + } + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: + mock_generate.return_value = { + "status": "completed", + "export_path": {"local": "test.json"}, + "results": {"python_basics": [{"question": "test?", "solution": "test!"}]} + } + response = client.post("/synthesis/generate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + assert res_json.get("status") == "completed" + assert "export_path" in res_json + +def test_generate_endpoint_with_doc_paths(): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "num_questions": 2, + "doc_paths": ["test.pdf"], + "is_demo": True + } + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: + mock_generate.return_value = { + "status": "completed", + "export_path": {"local": "test.json"}, + "results": {"chunk1": [{"question": "test?", "solution": "test!"}]} + } + response = client.post("/synthesis/generate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + assert res_json.get("status") == "completed" + assert "export_path" in res_json + +def test_generation_history(): + # Patch db_manager.get_paginated_generate_metadata to return dummy metadata with pagination info + db_manager.get_paginated_generate_metadata_light = lambda page, page_size: ( + 1, # total_count + [{"generate_file_name": "qa_pairs_claude_20250210T170521148_test.json", + "timestamp": "2024-02-10T12:00:00", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0"}] + ) + + # Since get_pending_generate_job_ids might be called, we should patch it too + db_manager.get_pending_generate_job_ids = lambda: [] + + response = client.get("/generations/history?page=1&page_size=10") + assert response.status_code == 200 + res_json = response.json() + + # Check that the response contains the expected structure + assert "data" in res_json + assert "pagination" in res_json + + # Check pagination metadata + assert res_json["pagination"]["total"] == 1 + assert res_json["pagination"]["page"] == 1 + assert res_json["pagination"]["page_size"] == 10 + assert res_json["pagination"]["total_pages"] == 1 + + # Check data contents + assert len(res_json["data"]) > 0 + # Instead of expecting exactly "test.json", assert the filename contains "_test.json" + assert "_test.json" in res_json["data"][0]["generate_file_name"] + +def test_error_handling(): + request_data = { + "use_case": "custom", + "model_id": "invalid.model", + "is_demo": True + } + response = client.post("/synthesis/generate", json=request_data) + # Expect an error with status code in [400,503] and key "error" + assert response.status_code in [400, 503] + res_json = response.json() + assert "error" in res_json diff --git a/tests/unit/test_evaluator_freeform_service.py b/tests/unit/test_evaluator_freeform_service.py new file mode 100644 index 00000000..38c8691d --- /dev/null +++ b/tests/unit/test_evaluator_freeform_service.py @@ -0,0 +1,80 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.evaluator_service import EvaluatorService +from app.models.request_models import EvaluationRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_freeform_data(): + return 
[{"field1": "value1", "field2": "value2", "field3": "value3"}] + +@pytest.fixture +def mock_freeform_file(tmp_path, mock_freeform_data): + file_path = tmp_path / "freeform_data.json" + with open(file_path, "w") as f: + json.dump(mock_freeform_data, f) + return str(file_path) + +@pytest.fixture +def evaluator_freeform_service(): + service = EvaluatorService() + service.db = MockDatabaseManager() + return service + +def test_evaluate_row_data(evaluator_freeform_service, mock_freeform_file): + request = EvaluationRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + use_case="custom", + import_path=mock_freeform_file, + is_demo=True, + output_key="field1", + output_value="field2" + ) + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good freeform data"}] + result = evaluator_freeform_service.evaluate_row_data(request) + assert result["status"] == "completed" + assert "output_path" in result + assert len(evaluator_freeform_service.db.evaluation_metadata) == 1 + +def test_evaluate_single_row(evaluator_freeform_service): + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_response = [{"score": 4, "justification": "Good freeform row"}] + mock_handler.return_value.generate_response.return_value = mock_response + + row = {"field1": "value1", "field2": "value2"} + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="field1", + output_value="field2" + ) + result = evaluator_freeform_service.evaluate_single_row(row, mock_handler.return_value, request) + assert result["evaluation"]["score"] == 4 + assert "justification" in result["evaluation"] + assert result["row"] == row + +def test_evaluate_rows(evaluator_freeform_service): + rows = [ + {"field1": "value1", "field2": "value2"}, + {"field1": "value3", "field2": "value4"} + ] + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="field1", + output_value="field2" + ) + + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good row"}] + result = evaluator_freeform_service.evaluate_rows(rows, mock_handler.return_value, request) + + assert result["total_evaluated"] == 2 + assert result["average_score"] == 4 + assert len(result["evaluated_rows"]) == 2 diff --git a/tests/unit/test_evaluator_legacy_service.py b/tests/unit/test_evaluator_legacy_service.py new file mode 100644 index 00000000..c7dded81 --- /dev/null +++ b/tests/unit/test_evaluator_legacy_service.py @@ -0,0 +1,83 @@ +import pytest +from io import StringIO +from unittest.mock import patch +import json +from app.services.evaluator_legacy_service import EvaluatorLegacyService +from app.models.request_models import EvaluationRequest +from tests.mocks.mock_db import MockDatabaseManager +from app.core.exceptions import ModelHandlerError, APIError + +@pytest.fixture +def mock_qa_data(): + return [{"question": "test question?", "solution": "test solution"}] + +@pytest.fixture +def mock_qa_file(tmp_path, mock_qa_data): + file_path = tmp_path / "qa_pairs.json" + with open(file_path, "w") as f: + json.dump(mock_qa_data, f) + return str(file_path) + +@pytest.fixture +def evaluator_service(): + service = EvaluatorLegacyService() + service.db = 
MockDatabaseManager() + return service + +def test_evaluate_results(evaluator_service, mock_qa_file): + request = EvaluationRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + use_case="custom", + import_path=mock_qa_file, + is_demo=True, + output_key="Prompt", + output_value="Completion" + ) + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}] + result = evaluator_service.evaluate_results(request) + assert result["status"] == "completed" + assert "output_path" in result + assert len(evaluator_service.db.evaluation_metadata) == 1 + +def test_evaluate_single_pair(): + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_response = [{"score": 4, "justification": "Good explanation"}] + mock_handler.return_value.generate_response.return_value = mock_response + service = EvaluatorLegacyService() + qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"} + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="Prompt", + output_value="Completion" + ) + result = service.evaluate_single_pair(qa_pair, mock_handler.return_value, request) + assert result["evaluation"]["score"] == 4 + assert "justification" in result["evaluation"] + +def test_evaluate_results_with_error(): + fake_json = '[{"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"}]' + class DummyHandler: + def generate_response(self, prompt, **kwargs): # Accept any keyword arguments + raise ModelHandlerError("Test error") + with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \ + patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \ + patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \ + patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): + service = EvaluatorLegacyService() + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + import_path="test.json", + is_demo=True, + output_key="Prompt", + output_value="Completion", + caii_endpoint="dummy_endpoint", + display_name="dummy" + ) + with pytest.raises(APIError, match="Test error"): + service.evaluate_results(request) diff --git a/tests/unit/test_evaluator_service.py b/tests/unit/test_evaluator_service.py index 629b80a3..c7dded81 100644 --- a/tests/unit/test_evaluator_service.py +++ b/tests/unit/test_evaluator_service.py @@ -2,7 +2,7 @@ from io import StringIO from unittest.mock import patch import json -from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import EvaluationRequest from tests.mocks.mock_db import MockDatabaseManager from app.core.exceptions import ModelHandlerError, APIError @@ -20,7 +20,7 @@ def mock_qa_file(tmp_path, mock_qa_data): @pytest.fixture def evaluator_service(): - service = EvaluatorService() + service = EvaluatorLegacyService() service.db = MockDatabaseManager() return service @@ -33,7 +33,7 @@ def test_evaluate_results(evaluator_service, mock_qa_file): output_key="Prompt", output_value="Completion" ) - with patch('app.services.evaluator_service.create_handler') as mock_handler: + 
with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}] result = evaluator_service.evaluate_results(request) assert result["status"] == "completed" @@ -41,10 +41,10 @@ def test_evaluate_results(evaluator_service, mock_qa_file): assert len(evaluator_service.db.evaluation_metadata) == 1 def test_evaluate_single_pair(): - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_response = [{"score": 4, "justification": "Good explanation"}] mock_handler.return_value.generate_response.return_value = mock_response - service = EvaluatorService() + service = EvaluatorLegacyService() qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"} request = EvaluationRequest( use_case="custom", @@ -63,11 +63,11 @@ def test_evaluate_results_with_error(): class DummyHandler: def generate_response(self, prompt, **kwargs): # Accept any keyword arguments raise ModelHandlerError("Test error") - with patch('app.services.evaluator_service.os.path.exists', return_value=True), \ + with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \ patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \ - patch('app.services.evaluator_service.create_handler', return_value=DummyHandler()), \ - patch('app.services.evaluator_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): - service = EvaluatorService() + patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \ + patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): + service = EvaluatorLegacyService() request = EvaluationRequest( use_case="custom", model_id="test.model", diff --git a/tests/unit/test_synthesis_freeform_service.py b/tests/unit/test_synthesis_freeform_service.py new file mode 100644 index 00000000..f0bcb861 --- /dev/null +++ b/tests/unit/test_synthesis_freeform_service.py @@ -0,0 +1,70 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.synthesis_service import SynthesisService +from app.models.request_models import SynthesisRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_json_data(): + return [{"topic": "test_topic", "example_field": "test_value"}] + +@pytest.fixture +def mock_file(tmp_path, mock_json_data): + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(mock_json_data, f) + return str(file_path) + +@pytest.fixture +def synthesis_freeform_service(): + service = SynthesisService() + service.db = MockDatabaseManager() + return service + +@pytest.mark.asyncio +async def test_generate_freeform_with_topics(synthesis_freeform_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=2, + topics=["test_topic"], + is_demo=True, + use_case="custom", + technique="freeform" + ) + with patch('app.services.synthesis_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"field1": "value1", "field2": "value2"}] + result = await synthesis_freeform_service.generate_freeform(request) + assert result["status"] == "completed" + assert len(synthesis_freeform_service.db.generation_metadata) == 1 + assert 
synthesis_freeform_service.db.generation_metadata[0]["model_id"] == request.model_id + +@pytest.mark.asyncio +async def test_generate_freeform_with_custom_examples(synthesis_freeform_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + topics=["test_topic"], + is_demo=True, + use_case="custom", + technique="freeform", + example_custom=[{"example_field": "example_value"}] + ) + with patch('app.services.synthesis_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"generated_field": "generated_value"}] + result = await synthesis_freeform_service.generate_freeform(request) + assert result["status"] == "completed" + assert "export_path" in result + +def test_validate_freeform_item(synthesis_freeform_service): + # Valid freeform item + valid_item = {"field1": "value1", "field2": "value2"} + assert synthesis_freeform_service._validate_freeform_item(valid_item) == True + + # Invalid freeform item (empty dict) + invalid_item = {} + assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False + + # Invalid freeform item (not a dict) + invalid_item = "not a dict" + assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False diff --git a/tests/unit/test_synthesis_legacy_service.py b/tests/unit/test_synthesis_legacy_service.py new file mode 100644 index 00000000..120fad7c --- /dev/null +++ b/tests/unit/test_synthesis_legacy_service.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.synthesis_legacy_service import SynthesisLegacyService +from app.models.request_models import SynthesisRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_json_data(): + return [{"input": "test question?"}] + +@pytest.fixture +def mock_file(tmp_path, mock_json_data): + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(mock_json_data, f) + return str(file_path) + +@pytest.fixture +def synthesis_service(): + service = SynthesisLegacyService() + service.db = MockDatabaseManager() + return service + +@pytest.mark.asyncio +async def test_generate_examples_with_topics(synthesis_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + topics=["test_topic"], + is_demo=True, + use_case="custom" + ) + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] + result = await synthesis_service.generate_examples(request) + assert result["status"] == "completed" + assert len(synthesis_service.db.generation_metadata) == 1 + assert synthesis_service.db.generation_metadata[0]["model_id"] == request.model_id + +@pytest.mark.asyncio +async def test_generate_examples_with_doc_paths(synthesis_service, mock_file): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + doc_paths=[mock_file], + is_demo=True, + use_case="custom" + ) + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \ + patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor: + mock_processor.return_value.process_document.return_value = ["chunk1"] + mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] + result = await 
synthesis_service.generate_examples(request) + assert result["status"] == "completed" + assert len(synthesis_service.db.generation_metadata) == 1 diff --git a/tests/unit/test_synthesis_service.py b/tests/unit/test_synthesis_service.py index 6e9ca410..120fad7c 100644 --- a/tests/unit/test_synthesis_service.py +++ b/tests/unit/test_synthesis_service.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, Mock import json -from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService from app.models.request_models import SynthesisRequest from tests.mocks.mock_db import MockDatabaseManager @@ -18,7 +18,7 @@ def mock_file(tmp_path, mock_json_data): @pytest.fixture def synthesis_service(): - service = SynthesisService() + service = SynthesisLegacyService() service.db = MockDatabaseManager() return service @@ -31,7 +31,7 @@ async def test_generate_examples_with_topics(synthesis_service): is_demo=True, use_case="custom" ) - with patch('app.services.synthesis_service.create_handler') as mock_handler: + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] result = await synthesis_service.generate_examples(request) assert result["status"] == "completed" @@ -47,8 +47,8 @@ async def test_generate_examples_with_doc_paths(synthesis_service, mock_file): is_demo=True, use_case="custom" ) - with patch('app.services.synthesis_service.create_handler') as mock_handler, \ - patch('app.services.synthesis_service.DocumentProcessor') as mock_processor: + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \ + patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor: mock_processor.return_value.process_document.return_value = ["chunk1"] mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] result = await synthesis_service.generate_examples(request)
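
A minimal sketch, not part of the patch, of how the request-level concurrency override introduced above (max_concurrent_topics, falling back to MAX_CONCURRENT_TOPICS when omitted) might be exercised against the slimmed-down freeform SynthesisService. It mirrors the style of the freeform unit tests in this diff; the field name comes from the hunks above, but treating it as an optional SynthesisRequest field and the test name itself are assumptions.

import pytest
from unittest.mock import patch

from app.services.synthesis_service import SynthesisService
from app.models.request_models import SynthesisRequest
from tests.mocks.mock_db import MockDatabaseManager


@pytest.mark.asyncio
async def test_generate_freeform_with_concurrency_override():
    # Use the mock DB so no real metadata is persisted.
    service = SynthesisService()
    service.db = MockDatabaseManager()

    request = SynthesisRequest(
        model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
        num_questions=2,
        topics=["topic_a", "topic_b"],
        is_demo=True,
        use_case="custom",
        technique="freeform",
        max_concurrent_topics=2,  # assumed optional field; service falls back to MAX_CONCURRENT_TOPICS if unset
    )

    # Stub the model handler so no real inference call is made.
    with patch('app.services.synthesis_service.create_handler') as mock_handler:
        mock_handler.return_value.generate_response.return_value = [{"field1": "value1", "field2": "value2"}]
        result = await service.generate_freeform(request)

    assert result["status"] == "completed"
    assert "export_path" in result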