diff --git a/.gitignore b/.gitignore
index 0549d46a..a4dfe0da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,6 +55,16 @@ qa_pairs*
Khauneesh/
*job_args*
+# Generated data files
+freeform_data_*.json
+row_data_*.json
+lending_*.json
+seeds_*.json
+SeedsInstructions.json
+*_example.json
+nm.json
+french_input.json
+
# DB
*metadata.db-shm
*metadata.db-wal
diff --git a/app/client/src/pages/DataGenerator/Finish.tsx b/app/client/src/pages/DataGenerator/Finish.tsx
index b54e5fee..60f7d4de 100644
--- a/app/client/src/pages/DataGenerator/Finish.tsx
+++ b/app/client/src/pages/DataGenerator/Finish.tsx
@@ -255,7 +255,8 @@ const Finish = () => {
title: 'Review Dataset',
description: 'Review your dataset to ensure it properly fits your usecase.',
icon: ,
- href: getFilesURL(genDatasetResp?.export_path?.local || "")
+ href: getFilesURL(genDatasetResp?.export_path?.local || ""),
+ external: true
},
{
avatar: '',
@@ -278,7 +279,8 @@ const Finish = () => {
title: 'Review Dataset',
description: 'Once your dataset finishes generating, you can review your dataset in the workbench files',
icon: ,
- href: getFilesURL('')
+ href: getFilesURL(''),
+ external: true
},
{
avatar: '',
@@ -361,7 +363,18 @@ const Finish = () => {
(
+ renderItem={({ title, href, icon, description, external }, i) => (
+ external ?
+
+
+ }
+ title={title}
+ description={description}
+ />
+
+ :
+
- color: ${props => props.theme.color};
- background-color: ${props => props.theme.backgroundColor};
- border: 1px solid ${props => props.theme.borderColor};
+const StyledTag = styled(Tag)<{ $theme: { color: string; backgroundColor: string; borderColor: string } }>`
+ color: ${props => props.$theme.color} !important;
+ background-color: ${props => props.$theme.backgroundColor} !important;
+ border: 1px solid ${props => props.$theme.borderColor} !important;
`;
@@ -150,7 +150,7 @@ const TemplateCard: React.FC = ({ template }) => {
const { color, backgroundColor, borderColor } = getTemplateTagColors(theme as string);
return (
- <StyledTag theme={{ color, backgroundColor, borderColor }}>
+ <StyledTag $theme={{ color, backgroundColor, borderColor }}>
{tag}
diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py
index 92391b8c..b61a0ec7 100644
--- a/app/core/model_handlers.py
+++ b/app/core/model_handlers.py
@@ -180,6 +180,8 @@ def generate_response(
return self._handle_caii_request(prompt)
if self.inference_type == "openai":
return self._handle_openai_request(prompt)
+ if self.inference_type == "openai_compatible":
+ return self._handle_openai_compatible_request(prompt)
if self.inference_type == "gemini":
return self._handle_gemini_request(prompt)
raise ModelHandlerError(f"Unsupported inference_type={self.inference_type}", 400)
@@ -342,6 +344,66 @@ def _handle_openai_request(self, prompt: str):
except Exception as e:
raise ModelHandlerError(f"OpenAI request failed: {e}", 500)
+ # ---------- OpenAI Compatible -------------------------------------------------------
+ def _handle_openai_compatible_request(self, prompt: str):
+ """Handle OpenAI compatible endpoints with proper timeout configuration"""
+ try:
+ import httpx
+ from openai import OpenAI
+
+ # Get API key from environment variable (only credential needed)
+ api_key = os.getenv('OpenAI_Endpoint_Compatible_Key')
+ if not api_key:
+ raise ModelHandlerError("OpenAI_Endpoint_Compatible_Key environment variable not set", 500)
+
+ # Base URL comes from caii_endpoint parameter (passed during initialization)
+ openai_compatible_endpoint = self.caii_endpoint
+ if not openai_compatible_endpoint:
+ raise ModelHandlerError("OpenAI compatible endpoint not provided", 500)
+
+ # Configure timeout for OpenAI compatible client (same as OpenAI v1.57.2)
+ timeout_config = httpx.Timeout(
+ connect=self.OPENAI_CONNECT_TIMEOUT,
+ read=self.OPENAI_READ_TIMEOUT,
+ write=10.0,
+ pool=5.0
+ )
+
+ # Configure httpx client with certificate verification for private cloud
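+ # /etc/ssl/certs/ca-certificates.crt is the standard CA bundle path installed by the
+ # ca-certificates package on Debian/Ubuntu-based images; when absent, fall back to httpx defaults.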
+ if os.path.exists("/etc/ssl/certs/ca-certificates.crt"):
+ http_client = httpx.Client(
+ verify="/etc/ssl/certs/ca-certificates.crt",
+ timeout=timeout_config
+ )
+ else:
+ http_client = httpx.Client(timeout=timeout_config)
+
+ # Remove trailing '/chat/completions' if present (similar to CAII handling)
+ openai_compatible_endpoint = openai_compatible_endpoint.removesuffix('/chat/completions')
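+ # Note: str.removesuffix requires Python 3.9+; the OpenAI client appends '/chat/completions'
+ # to base_url itself, so the suffix must not remain on the base URL.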
+
+ client = OpenAI(
+ api_key=api_key,
+ base_url=openai_compatible_endpoint,
+ http_client=http_client
+ )
+
+ completion = client.chat.completions.create(
+ model=self.model_id,
+ messages=[{"role": "user", "content": prompt}],
+ max_tokens=self.model_params.max_tokens,
+ temperature=self.model_params.temperature,
+ top_p=self.model_params.top_p,
+ stream=False,
+ )
+
+ print("generated via OpenAI Compatible endpoint")
+ response_text = completion.choices[0].message.content
+
+ return self._extract_json_from_text(response_text) if not self.custom_p else response_text
+
+ except Exception as e:
+ raise ModelHandlerError(f"OpenAI Compatible request failed: {str(e)}", status_code=500)
+
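+ # Illustrative usage of _handle_openai_compatible_request (the endpoint URL below is an
+ # example, not part of this module): a handler built with
+ #   create_handler(model_id, bedrock_client, model_params=params,
+ #                  inference_type="openai_compatible",
+ #                  caii_endpoint="https://llm.example.com/v1")
+ # reads the API key from OpenAI_Endpoint_Compatible_Key and reuses caii_endpoint as the base URL.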
# ---------- Gemini -------------------------------------------------------
def _handle_gemini_request(self, prompt: str):
if genai is None:
diff --git a/app/main.py b/app/main.py
index b81bc05c..2becafb8 100644
--- a/app/main.py
+++ b/app/main.py
@@ -42,8 +42,10 @@
sys.path.append(str(ROOT_DIR))
from app.services.evaluator_service import EvaluatorService
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath, Technique
from app.services.synthesis_service import SynthesisService
+from app.services.synthesis_legacy_service import SynthesisLegacyService
from app.services.export_results import Export_Service
from app.core.prompt_templates import PromptBuilder, PromptHandler
@@ -66,8 +68,10 @@
#****************************************Initialize************************************************
# Initialize services
-synthesis_service = SynthesisService()
-evaluator_service = EvaluatorService()
+synthesis_service = SynthesisService() # Freeform only
+synthesis_legacy_service = SynthesisLegacyService() # SFT and Custom_Workflow
+evaluator_service = EvaluatorService() # Freeform only
+evaluator_legacy_service = EvaluatorLegacyService() # SFT and Custom_Workflow
export_service = Export_Service()
db_manager = DatabaseManager()
@@ -552,9 +556,11 @@ async def generate_examples(request: SynthesisRequest):
if is_demo== True:
if request.input_path:
- return await synthesis_service.generate_result(request,is_demo, request_id=request_id)
+ # Custom_Workflow technique - route to legacy service
+ return await synthesis_legacy_service.generate_result(request,is_demo, request_id=request_id)
else:
- return await synthesis_service.generate_examples(request,is_demo, request_id=request_id)
+ # SFT technique - route to legacy service
+ return await synthesis_legacy_service.generate_examples(request,is_demo, request_id=request_id)
else:
return synthesis_job.generate_job(request, core, mem, request_id=request_id)
@@ -626,7 +632,8 @@ async def evaluate_examples(request: EvaluationRequest):
is_demo = request.is_demo
if is_demo:
- return evaluator_service.evaluate_results(request, request_id=request_id)
+ # SFT and Custom_Workflow evaluation - route to legacy service
+ return evaluator_legacy_service.evaluate_results(request, request_id=request_id)
else:
return synthesis_job.evaluate_job(request, request_id=request_id)
@@ -1242,7 +1249,7 @@ def is_empty(self):
async def health_check():
"""Get API health status"""
#return {"status": "healthy"}
- return synthesis_service.get_health_check()
+ return synthesis_legacy_service.get_health_check()
@app.get("/{use_case}/example_payloads")
async def get_example_payloads(use_case:UseCase):
@@ -1255,6 +1262,7 @@ async def get_example_payloads(use_case:UseCase):
"technique": "sft",
"topics": ["python_basics", "data_structures"],
"is_demo": True,
+ "max_concurrent_topics": 5,
"examples": [
{
"question": "How do you create a list in Python and add elements to it?",
@@ -1281,6 +1289,7 @@ async def get_example_payloads(use_case:UseCase):
"technique": "sft",
"topics": ["basic_queries", "joins"],
"is_demo": True,
+ "max_concurrent_topics": 5,
"schema": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100), email VARCHAR(255));\nCREATE TABLE orders (id INT PRIMARY KEY, user_id INT, amount DECIMAL(10,2), FOREIGN KEY (user_id) REFERENCES users(id));",
"examples":[
{
@@ -1309,6 +1318,7 @@ async def get_example_payloads(use_case:UseCase):
"topics": ["topic 1", "topic 2"],
"custom_prompt": "Give your instructions here",
"is_demo": True,
+ "max_concurrent_topics": 5,
"examples":[
{
diff --git a/app/models/request_models.py b/app/models/request_models.py
index eaf9d7e0..bf22343f 100644
--- a/app/models/request_models.py
+++ b/app/models/request_models.py
@@ -123,6 +123,7 @@ class SynthesisRequest(BaseModel):
# Optional fields that can override defaults
inference_type: Optional[str] = "aws_bedrock"
caii_endpoint: Optional[str] = None
+ openai_compatible_endpoint: Optional[str] = None
topics: Optional[List[str]] = None
doc_paths: Optional[List[str]] = None
input_path: Optional[List[str]] = None
@@ -137,7 +138,13 @@ class SynthesisRequest(BaseModel):
example_path: Optional[str] = None
schema: Optional[str] = None # Added schema field
custom_prompt: Optional[str] = None
- display_name: Optional[str] = None
+ display_name: Optional[str] = None
+ max_concurrent_topics: Optional[int] = Field(
+ default=5,
+ ge=1,
+ le=100,
+ description="Maximum number of concurrent topics to process (1-100)"
+ )
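+ # Consumed by SynthesisLegacyService.generate_examples as the topic ThreadPoolExecutor size
+ # (request.max_concurrent_topics or the service's MAX_CONCURRENT_TOPICS default of 5).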
# Optional model parameters with defaults
model_params: Optional[ModelParameters] = Field(
@@ -155,7 +162,7 @@ class SynthesisRequest(BaseModel):
"technique": "sft",
"topics": ["python_basics", "data_structures"],
"is_demo": True,
-
+ "max_concurrent_topics": 5
}
}
@@ -208,6 +215,12 @@ class EvaluationRequest(BaseModel):
display_name: Optional[str] = None
output_key: Optional[str] = 'Prompt'
output_value: Optional[str] = 'Completion'
+ max_workers: Optional[int] = Field(
+ default=4,
+ ge=1,
+ le=100,
+ description="Maximum number of worker threads for parallel evaluation (1-100)"
+ )
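+ # Consumed by the evaluator services as the ThreadPoolExecutor size
+ # (request.max_workers or the service's own max_workers default).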
# Export configuration
export_type: str = "local" # "local" or "s3"
@@ -226,7 +239,8 @@ class EvaluationRequest(BaseModel):
"inference_type": "aws_bedrock",
"import_path": "qa_pairs_llama3-1-70b-instruct-v1:0_20241114_212837_test.json",
"import_type": "local",
- "export_type":"local"
+ "export_type":"local",
+ "max_workers": 4
}
}
diff --git a/app/run_eval_job.py b/app/run_eval_job.py
index 9591cee4..7a60714d 100644
--- a/app/run_eval_job.py
+++ b/app/run_eval_job.py
@@ -31,6 +31,7 @@
from app.models.request_models import EvaluationRequest, ModelParameters
from app.services.evaluator_service import EvaluatorService
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
import asyncio
import nest_asyncio
@@ -40,7 +41,7 @@
async def run_eval(request, job_name, request_id):
try:
- job = EvaluatorService()
+ job = EvaluatorLegacyService()
result = job.evaluate_results(request,job_name, is_demo=False, request_id=request_id)
return result
except Exception as e:
diff --git a/app/run_job.py b/app/run_job.py
index 6d15833a..213c6d0a 100644
--- a/app/run_job.py
+++ b/app/run_job.py
@@ -32,6 +32,7 @@
import json
from app.models.request_models import SynthesisRequest
from app.services.synthesis_service import SynthesisService
+from app.services.synthesis_legacy_service import SynthesisLegacyService
import asyncio
import nest_asyncio # Add this import
@@ -41,7 +42,7 @@
async def run_synthesis(request, job_name, request_id):
"""Run standard synthesis job for question-answer pairs"""
try:
- job = SynthesisService()
+ job = SynthesisLegacyService()
if request.input_path:
result = await job.generate_result(request, job_name, is_demo=False, request_id=request_id)
else:
diff --git a/app/services/evaluator_legacy_service.py b/app/services/evaluator_legacy_service.py
new file mode 100644
index 00000000..328c1f83
--- /dev/null
+++ b/app/services/evaluator_legacy_service.py
@@ -0,0 +1,409 @@
+import boto3
+from typing import Dict, List, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from app.models.request_models import Example, ModelParameters, EvaluationRequest
+from app.core.model_handlers import create_handler
+from app.core.prompt_templates import PromptBuilder, PromptHandler
+from app.services.aws_bedrock import get_bedrock_client
+from app.core.database import DatabaseManager
+from app.core.config import UseCase, Technique, get_model_family
+from app.services.check_guardrail import ContentGuardrail
+from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError
+import os
+from datetime import datetime, timezone
+import json
+import logging
+from logging.handlers import RotatingFileHandler
+from app.core.telemetry_integration import track_llm_operation
+from functools import partial
+
+class EvaluatorLegacyService:
+ """Legacy service for evaluating generated QA pairs using Claude with parallel processing (SFT and Custom_Workflow only)"""
+
+ def __init__(self, max_workers: int = 4):
+ self.bedrock_client = get_bedrock_client()
+ self.db = DatabaseManager()
+ self.max_workers = max_workers # Default max workers (configurable via request)
+ self.guard = ContentGuardrail()
+ self._setup_logging()
+
+ def _setup_logging(self):
+ """Set up logging configuration"""
+ os.makedirs('logs', exist_ok=True)
+
+ self.logger = logging.getLogger('evaluator_legacy_service')
+ self.logger.setLevel(logging.INFO)
+
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+
+ # File handler for general logs
+ file_handler = RotatingFileHandler(
+ 'logs/evaluator_legacy_service.log',
+ maxBytes=10*1024*1024, # 10MB
+ backupCount=5
+ )
+ file_handler.setFormatter(formatter)
+ self.logger.addHandler(file_handler)
+
+ # File handler for errors
+ error_handler = RotatingFileHandler(
+ 'logs/evaluator_legacy_service_errors.log',
+ maxBytes=10*1024*1024,
+ backupCount=5
+ )
+ error_handler.setLevel(logging.ERROR)
+ error_handler.setFormatter(formatter)
+ self.logger.addHandler(error_handler)
+
+
+ #@track_llm_operation("evaluate_single_pair")
+ def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict:
+ """Evaluate a single QA pair"""
+ try:
+ # Default error response
+ error_response = {
+ request.output_key: qa_pair.get(request.output_key, "Unknown"),
+ request.output_value: qa_pair.get(request.output_value, "Unknown"),
+ "evaluation": {
+ "score": 0,
+ "justification": "Error during evaluation"
+ }
+ }
+
+ try:
+ self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...")
+ except Exception as e:
+ self.logger.error(f"Error logging QA pair: {str(e)}")
+
+ try:
+ # Validate input qa_pair structure
+ if not all(key in qa_pair for key in [request.output_key, request.output_value]):
+ error_msg = "Missing required keys in qa_pair"
+ self.logger.error(error_msg)
+ error_response["evaluation"]["justification"] = error_msg
+ return error_response
+
+ prompt = PromptBuilder.build_eval_prompt(
+ request.model_id,
+ request.use_case,
+ qa_pair[request.output_key],
+ qa_pair[request.output_value],
+ request.examples,
+ request.custom_prompt
+ )
+ #print(prompt)
+ except Exception as e:
+ error_msg = f"Error building evaluation prompt: {str(e)}"
+ self.logger.error(error_msg)
+ error_response["evaluation"]["justification"] = error_msg
+ return error_response
+
+ try:
+ response = model_handler.generate_response(prompt, request_id=request_id)
+ except ModelHandlerError as e:
+ self.logger.error(f"ModelHandlerError in generate_response: {str(e)}")
+ raise
+ except Exception as e:
+ error_msg = f"Error generating model response: {str(e)}"
+ self.logger.error(error_msg)
+ error_response["evaluation"]["justification"] = error_msg
+ return error_response
+
+ if not response:
+ error_msg = "Failed to parse model response"
+ self.logger.warning(error_msg)
+ error_response["evaluation"]["justification"] = error_msg
+ return error_response
+
+ try:
+ score = response[0].get('score', "no score key")
+ justification = response[0].get('justification', 'No justification provided')
+ if score== "no score key":
+ self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}")
+ justification = "The evaluated pair did not generate valid score and justification"
+ score = 0
+ else:
+ self.logger.info(f"Successfully evaluated QA pair with score: {score}")
+
+ return {
+ "question": qa_pair[request.output_key],
+ "solution": qa_pair[request.output_value],
+ "evaluation": {
+ "score": score,
+ "justification": justification
+ }
+ }
+ except Exception as e:
+ error_msg = f"Error processing model response: {str(e)}"
+ self.logger.error(error_msg)
+ error_response["evaluation"]["justification"] = error_msg
+ return error_response
+
+ except ModelHandlerError:
+ raise
+ except Exception as e:
+ self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}")
+ return error_response
+
+ #@track_llm_operation("evaluate_topic")
+ def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict:
+ """Evaluate all QA pairs for a given topic in parallel"""
+ try:
+ self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs")
+ evaluated_pairs = []
+ failed_pairs = []
+
+ try:
+ max_workers = request.max_workers or self.max_workers
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ try:
+ evaluate_func = partial(
+ self.evaluate_single_pair,
+ model_handler=model_handler,
+ request=request, request_id=request_id
+ )
+
+ future_to_pair = {
+ executor.submit(evaluate_func, pair): pair
+ for pair in qa_pairs
+ }
+
+ for future in as_completed(future_to_pair):
+ try:
+ result = future.result()
+ evaluated_pairs.append(result)
+ except ModelHandlerError:
+ raise
+ except Exception as e:
+ error_msg = f"Error processing future result: {str(e)}"
+ self.logger.error(error_msg)
+ failed_pairs.append({
+ "error": error_msg,
+ "pair": future_to_pair[future]
+ })
+
+ except Exception as e:
+ error_msg = f"Error in parallel execution: {str(e)}"
+ self.logger.error(error_msg)
+ raise
+
+ except ModelHandlerError:
+ raise
+ except Exception as e:
+ error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}"
+ self.logger.error(error_msg)
+ raise
+
+ try:
+ # Calculate statistics only from successful evaluations
+ scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None]
+
+ if scores:
+ average_score = sum(scores) / len(scores)
+ average_score = round(average_score, 2)
+ min_score = min(scores)
+ max_score = max(scores)
+ else:
+ average_score = min_score = max_score = 0
+
+ topic_stats = {
+ "average_score": average_score,
+ "min_score": min_score,
+ "max_score": max_score,
+ "evaluated_pairs": evaluated_pairs,
+ "failed_pairs": failed_pairs,
+ "total_evaluated": len(evaluated_pairs),
+ "total_failed": len(failed_pairs)
+ }
+
+ self.logger.info(f"Completed evaluation for topic: {topic}. Average score: {topic_stats['average_score']:.2f}")
+ return topic_stats
+
+ except Exception as e:
+ error_msg = f"Error calculating topic statistics: {str(e)}"
+ self.logger.error(error_msg)
+ return {
+ "average_score": 0,
+ "min_score": 0,
+ "max_score": 0,
+ "evaluated_pairs": evaluated_pairs,
+ "failed_pairs": failed_pairs,
+ "error": error_msg
+ }
+ except ModelHandlerError:
+ raise
+ except Exception as e:
+ error_msg = f"Critical error in evaluate_topic: {str(e)}"
+ self.logger.error(error_msg)
+ return {
+ "average_score": 0,
+ "min_score": 0,
+ "max_score": 0,
+ "evaluated_pairs": [],
+ "failed_pairs": [],
+ "error": error_msg
+ }
+
+ #@track_llm_operation("evaluate_results")
+ def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict:
+ """Evaluate all QA pairs with parallel processing"""
+ try:
+ self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}")
+
+ model_params = request.model_params or ModelParameters()
+
+ self.logger.info(f"Creating model handler for model: {request.model_id}")
+ model_handler = create_handler(
+ request.model_id,
+ self.bedrock_client,
+ model_params=model_params,
+ inference_type = request.inference_type,
+ caii_endpoint = request.caii_endpoint
+ )
+
+ self.logger.info(f"Loading QA pairs from: {request.import_path}")
+ with open(request.import_path, 'r') as file:
+ data = json.load(file)
+
+ evaluated_results = {}
+ all_scores = []
+
+ transformed_data = {
+ "results": {},
+ }
+ for item in data:
+ topic = item.get('Seeds')
+
+ # Create topic list if it doesn't exist
+ if topic not in transformed_data['results']:
+ transformed_data['results'][topic] = []
+
+ # Create QA pair
+ qa_pair = {
+ request.output_key: item.get(request.output_key, ''), # Use get() with default value
+ request.output_value: item.get(request.output_value, '') # Use get() with default value
+ }
+
+ # Add to appropriate topic list
+ transformed_data['results'][topic].append(qa_pair)
+
+ max_workers = request.max_workers or self.max_workers
+ self.logger.info(f"Processing {len(transformed_data['results'])} topics with {max_workers} workers")
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ future_to_topic = {
+ executor.submit(
+ self.evaluate_topic,
+ topic,
+ qa_pairs,
+ model_handler,
+ request, request_id=request_id
+ ): topic
+ for topic, qa_pairs in transformed_data['results'].items()
+ }
+
+ for future in as_completed(future_to_topic):
+ try:
+ topic = future_to_topic[future]
+ topic_stats = future.result()
+ evaluated_results[topic] = topic_stats
+ all_scores.extend([
+ pair["evaluation"]["score"]
+ for pair in topic_stats["evaluated_pairs"]
+ ])
+ except ModelHandlerError as e:
+ self.logger.error(f"ModelHandlerError in future processing: {str(e)}")
+ raise APIError(f"Model evaluation failed: {str(e)}")
+
+
+ overall_average = sum(all_scores) / len(all_scores) if all_scores else 0
+ overall_average = round(overall_average, 2)
+ evaluated_results['Overall_Average'] = overall_average
+
+ self.logger.info(f"Evaluation completed. Overall average score: {overall_average:.2f}")
+
+
+ timestamp = datetime.now(timezone.utc).isoformat()
+ time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
+ model_name = get_model_family(request.model_id).split('.')[-1]
+ output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json"
+
+ self.logger.info(f"Saving evaluation results to: {output_path}")
+ with open(output_path, 'w') as f:
+ json.dump(evaluated_results, f, indent=2)
+
+ custom_prompt_str = PromptHandler.get_default_custom_eval_prompt(
+ request.use_case,
+ request.custom_prompt
+ )
+
+
+ examples_value = (
+ PromptHandler.get_default_eval_example(request.use_case, request.examples)
+ if hasattr(request, 'examples')
+ else None
+ )
+ examples_str = self.safe_json_dumps(examples_value)
+ #print(examples_value, '\n',examples_str)
+
+ metadata = {
+ 'timestamp': timestamp,
+ 'model_id': request.model_id,
+ 'inference_type': request.inference_type,
+ 'caii_endpoint':request.caii_endpoint,
+ 'use_case': request.use_case,
+ 'custom_prompt': custom_prompt_str,
+ 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
+ 'generate_file_name': os.path.basename(request.import_path),
+ 'evaluate_file_name': os.path.basename(output_path),
+ 'display_name': request.display_name,
+ 'local_export_path': output_path,
+ 'examples': examples_str,
+ 'Overall_Average': overall_average
+ }
+
+ self.logger.info("Saving evaluation metadata to database")
+
+
+ if is_demo:
+ self.db.save_evaluation_metadata(metadata)
+ return {
+ "status": "completed",
+ "result": evaluated_results,
+ "output_path": output_path
+ }
+ else:
+
+
+ job_status = "ENGINE_SUCCEEDED"
+ evaluate_file_name = os.path.basename(output_path)
+ self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status)
+ self.db.backup_and_restore_db()
+ return {
+ "status": "completed",
+ "output_path": output_path
+ }
+ except APIError:
+ raise
+ except ModelHandlerError as e:
+ # Add this specific handler
+ self.logger.error(f"ModelHandlerError in evaluation: {str(e)}")
+ raise APIError(str(e))
+ except Exception as e:
+ error_msg = f"Error in evaluation process: {str(e)}"
+ self.logger.error(error_msg, exc_info=True)
+ if is_demo:
+ raise APIError(str(e))
+ else:
+ time_stamp = datetime.now(timezone.utc).isoformat()
+ job_status = "ENGINE_FAILED"
+ file_name = ''
+ output_path = ''
+ overall_average = ''
+ self.db.update_job_evaluate(job_name, file_name, output_path, time_stamp, overall_average, job_status)
+
+ raise
+
+ def safe_json_dumps(self, value):
+ """Convert value to JSON string only if it's not None"""
+ return json.dumps(value) if value is not None else None
diff --git a/app/services/evaluator_service.py b/app/services/evaluator_service.py
index 2b094b15..f3fbbcad 100644
--- a/app/services/evaluator_service.py
+++ b/app/services/evaluator_service.py
@@ -19,12 +19,12 @@
from functools import partial
class EvaluatorService:
- """Service for evaluating generated QA pairs using Claude with parallel processing"""
+ """Service for evaluating freeform data rows using Claude with parallel processing (Freeform technique only)"""
- def __init__(self, max_workers: int = 4):
+ def __init__(self, max_workers: int = 5):
self.bedrock_client = get_bedrock_client()
self.db = DatabaseManager()
- self.max_workers = max_workers
+ self.max_workers = max_workers # Default max workers (configurable via request)
self.guard = ContentGuardrail()
self._setup_logging()
@@ -56,351 +56,6 @@ def _setup_logging(self):
error_handler.setFormatter(formatter)
self.logger.addHandler(error_handler)
-
- #@track_llm_operation("evaluate_single_pair")
- def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict:
- """Evaluate a single QA pair"""
- try:
- # Default error response
- error_response = {
- request.output_key: qa_pair.get(request.output_key, "Unknown"),
- request.output_value: qa_pair.get(request.output_value, "Unknown"),
- "evaluation": {
- "score": 0,
- "justification": "Error during evaluation"
- }
- }
-
- try:
- self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...")
- except Exception as e:
- self.logger.error(f"Error logging QA pair: {str(e)}")
-
- try:
- # Validate input qa_pair structure
- if not all(key in qa_pair for key in [request.output_key, request.output_value]):
- error_msg = "Missing required keys in qa_pair"
- self.logger.error(error_msg)
- error_response["evaluation"]["justification"] = error_msg
- return error_response
-
- prompt = PromptBuilder.build_eval_prompt(
- request.model_id,
- request.use_case,
- qa_pair[request.output_key],
- qa_pair[request.output_value],
- request.examples,
- request.custom_prompt
- )
- #print(prompt)
- except Exception as e:
- error_msg = f"Error building evaluation prompt: {str(e)}"
- self.logger.error(error_msg)
- error_response["evaluation"]["justification"] = error_msg
- return error_response
-
- try:
- response = model_handler.generate_response(prompt, request_id=request_id)
- except ModelHandlerError as e:
- self.logger.error(f"ModelHandlerError in generate_response: {str(e)}")
- raise
- except Exception as e:
- error_msg = f"Error generating model response: {str(e)}"
- self.logger.error(error_msg)
- error_response["evaluation"]["justification"] = error_msg
- return error_response
-
- if not response:
- error_msg = "Failed to parse model response"
- self.logger.warning(error_msg)
- error_response["evaluation"]["justification"] = error_msg
- return error_response
-
- try:
- score = response[0].get('score', "no score key")
- justification = response[0].get('justification', 'No justification provided')
- if score== "no score key":
- self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}")
- justification = "The evaluated pair did not generate valid score and justification"
- score = 0
- else:
- self.logger.info(f"Successfully evaluated QA pair with score: {score}")
-
- return {
- "question": qa_pair[request.output_key],
- "solution": qa_pair[request.output_value],
- "evaluation": {
- "score": score,
- "justification": justification
- }
- }
- except Exception as e:
- error_msg = f"Error processing model response: {str(e)}"
- self.logger.error(error_msg)
- error_response["evaluation"]["justification"] = error_msg
- return error_response
-
- except ModelHandlerError:
- raise
- except Exception as e:
- self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}")
- return error_response
-
- #@track_llm_operation("evaluate_topic")
- def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict:
- """Evaluate all QA pairs for a given topic in parallel"""
- try:
- self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs")
- evaluated_pairs = []
- failed_pairs = []
-
- try:
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
- try:
- evaluate_func = partial(
- self.evaluate_single_pair,
- model_handler=model_handler,
- request=request, request_id=request_id
- )
-
- future_to_pair = {
- executor.submit(evaluate_func, pair): pair
- for pair in qa_pairs
- }
-
- for future in as_completed(future_to_pair):
- try:
- result = future.result()
- evaluated_pairs.append(result)
- except ModelHandlerError:
- raise
- except Exception as e:
- error_msg = f"Error processing future result: {str(e)}"
- self.logger.error(error_msg)
- failed_pairs.append({
- "error": error_msg,
- "pair": future_to_pair[future]
- })
-
- except Exception as e:
- error_msg = f"Error in parallel execution: {str(e)}"
- self.logger.error(error_msg)
- raise
-
- except ModelHandlerError:
- raise
- except Exception as e:
- error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}"
- self.logger.error(error_msg)
- raise
-
- try:
- # Calculate statistics only from successful evaluations
- scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None]
-
- if scores:
- average_score = sum(scores) / len(scores)
- average_score = round(average_score, 2)
- min_score = min(scores)
- max_score = max(scores)
- else:
- average_score = min_score = max_score = 0
-
- topic_stats = {
- "average_score": average_score,
- "min_score": min_score,
- "max_score": max_score,
- "evaluated_pairs": evaluated_pairs,
- "failed_pairs": failed_pairs,
- "total_evaluated": len(evaluated_pairs),
- "total_failed": len(failed_pairs)
- }
-
- self.logger.info(f"Completed evaluation for topic: {topic}. Average score: {topic_stats['average_score']:.2f}")
- return topic_stats
-
- except Exception as e:
- error_msg = f"Error calculating topic statistics: {str(e)}"
- self.logger.error(error_msg)
- return {
- "average_score": 0,
- "min_score": 0,
- "max_score": 0,
- "evaluated_pairs": evaluated_pairs,
- "failed_pairs": failed_pairs,
- "error": error_msg
- }
- except ModelHandlerError:
- raise
- except Exception as e:
- error_msg = f"Critical error in evaluate_topic: {str(e)}"
- self.logger.error(error_msg)
- return {
- "average_score": 0,
- "min_score": 0,
- "max_score": 0,
- "evaluated_pairs": [],
- "failed_pairs": [],
- "error": error_msg
- }
- #@track_llm_operation("evaluate_results")
- def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict:
- """Evaluate all QA pairs with parallel processing"""
- try:
- self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}")
-
- model_params = request.model_params or ModelParameters()
-
- self.logger.info(f"Creating model handler for model: {request.model_id}")
- model_handler = create_handler(
- request.model_id,
- self.bedrock_client,
- model_params=model_params,
- inference_type = request.inference_type,
- caii_endpoint = request.caii_endpoint
- )
-
- self.logger.info(f"Loading QA pairs from: {request.import_path}")
- with open(request.import_path, 'r') as file:
- data = json.load(file)
-
- evaluated_results = {}
- all_scores = []
-
- transformed_data = {
- "results": {},
- }
- for item in data:
- topic = item.get('Seeds')
-
- # Create topic list if it doesn't exist
- if topic not in transformed_data['results']:
- transformed_data['results'][topic] = []
-
- # Create QA pair
- qa_pair = {
- request.output_key: item.get(request.output_key, ''), # Use get() with default value
- request.output_value: item.get(request.output_value, '') # Use get() with default value
- }
-
- # Add to appropriate topic list
- transformed_data['results'][topic].append(qa_pair)
-
- self.logger.info(f"Processing {len(transformed_data['results'])} topics with {self.max_workers} workers")
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
- future_to_topic = {
- executor.submit(
- self.evaluate_topic,
- topic,
- qa_pairs,
- model_handler,
- request, request_id=request_id
- ): topic
- for topic, qa_pairs in transformed_data['results'].items()
- }
-
- for future in as_completed(future_to_topic):
- try:
- topic = future_to_topic[future]
- topic_stats = future.result()
- evaluated_results[topic] = topic_stats
- all_scores.extend([
- pair["evaluation"]["score"]
- for pair in topic_stats["evaluated_pairs"]
- ])
- except ModelHandlerError as e:
- self.logger.error(f"ModelHandlerError in future processing: {str(e)}")
- raise APIError(f"Model evaluation failed: {str(e)}")
-
-
- overall_average = sum(all_scores) / len(all_scores) if all_scores else 0
- overall_average = round(overall_average, 2)
- evaluated_results['Overall_Average'] = overall_average
-
- self.logger.info(f"Evaluation completed. Overall average score: {overall_average:.2f}")
-
-
- timestamp = datetime.now(timezone.utc).isoformat()
- time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
- model_name = get_model_family(request.model_id).split('.')[-1]
- output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json"
-
- self.logger.info(f"Saving evaluation results to: {output_path}")
- with open(output_path, 'w') as f:
- json.dump(evaluated_results, f, indent=2)
-
- custom_prompt_str = PromptHandler.get_default_custom_eval_prompt(
- request.use_case,
- request.custom_prompt
- )
-
-
- examples_value = (
- PromptHandler.get_default_eval_example(request.use_case, request.examples)
- if hasattr(request, 'examples')
- else None
- )
- examples_str = self.safe_json_dumps(examples_value)
- #print(examples_value, '\n',examples_str)
-
- metadata = {
- 'timestamp': timestamp,
- 'model_id': request.model_id,
- 'inference_type': request.inference_type,
- 'caii_endpoint':request.caii_endpoint,
- 'use_case': request.use_case,
- 'custom_prompt': custom_prompt_str,
- 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
- 'generate_file_name': os.path.basename(request.import_path),
- 'evaluate_file_name': os.path.basename(output_path),
- 'display_name': request.display_name,
- 'local_export_path': output_path,
- 'examples': examples_str,
- 'Overall_Average': overall_average
- }
-
- self.logger.info("Saving evaluation metadata to database")
-
-
- if is_demo:
- self.db.save_evaluation_metadata(metadata)
- return {
- "status": "completed",
- "result": evaluated_results,
- "output_path": output_path
- }
- else:
-
-
- job_status = "ENGINE_SUCCEEDED"
- evaluate_file_name = os.path.basename(output_path)
- self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status)
- self.db.backup_and_restore_db()
- return {
- "status": "completed",
- "output_path": output_path
- }
- except APIError:
- raise
- except ModelHandlerError as e:
- # Add this specific handler
- self.logger.error(f"ModelHandlerError in evaluation: {str(e)}")
- raise APIError(str(e))
- except Exception as e:
- error_msg = f"Error in evaluation process: {str(e)}"
- self.logger.error(error_msg, exc_info=True)
- if is_demo:
- raise APIError(str(e))
- else:
- time_stamp = datetime.now(timezone.utc).isoformat()
- job_status = "ENGINE_FAILED"
- file_name = ''
- output_path = ''
- overall_average = ''
- self.db.update_job_evaluate(job_name,file_name, output_path, time_stamp, job_status)
-
- raise
-
def evaluate_single_row(self, row: Dict[str, Any], model_handler, request: EvaluationRequest, request_id = None) -> Dict:
"""Evaluate a single data row"""
try:
@@ -488,7 +143,8 @@ def evaluate_rows(self, rows: List[Dict[str, Any]], model_handler, request: Eval
failed_rows = []
try:
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ max_workers = request.max_workers or self.max_workers
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
try:
evaluate_func = partial(
self.evaluate_single_row,
@@ -690,4 +346,4 @@ def evaluate_row_data(self, request: EvaluationRequest, job_name=None, is_demo:
def safe_json_dumps(self, value):
"""Convert value to JSON string only if it's not None"""
- return json.dumps(value) if value is not None else None
\ No newline at end of file
+ return json.dumps(value) if value is not None else None
diff --git a/app/services/model_alignment.py b/app/services/model_alignment.py
index 650005a3..2b3fcfe1 100644
--- a/app/services/model_alignment.py
+++ b/app/services/model_alignment.py
@@ -7,8 +7,8 @@
import asyncio
from datetime import datetime, timezone
from typing import Dict, Optional
-from app.services.synthesis_service import SynthesisService
-from app.services.evaluator_service import EvaluatorService
+from app.services.synthesis_legacy_service import SynthesisLegacyService
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
from app.models.request_models import SynthesisRequest, EvaluationRequest
from app.models.request_models import ModelParameters
from app.services.aws_bedrock import get_bedrock_client
@@ -21,8 +21,8 @@ class ModelAlignment:
"""Service for aligning model outputs through synthesis and evaluation"""
def __init__(self):
- self.synthesis_service = SynthesisService()
- self.evaluator_service = EvaluatorService()
+ self.synthesis_service = SynthesisLegacyService()
+ self.evaluator_service = EvaluatorLegacyService()
self.db = DatabaseManager()
self.bedrock_client = get_bedrock_client() # Add this line
self._setup_logging()
diff --git a/app/services/synthesis_job.py b/app/services/synthesis_job.py
index 323fa7f0..8ea42010 100644
--- a/app/services/synthesis_job.py
+++ b/app/services/synthesis_job.py
@@ -3,9 +3,9 @@
import uuid
import os
from typing import Dict, Any, Optional
-from app.services.evaluator_service import EvaluatorService
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath
-from app.services.synthesis_service import SynthesisService
+from app.services.synthesis_legacy_service import SynthesisLegacyService
from app.services.export_results import Export_Service
from app.core.prompt_templates import PromptBuilder, PromptHandler
from app.core.config import UseCase, USE_CASE_CONFIGS
@@ -22,8 +22,8 @@
import cmlapi
# Initialize services
-synthesis_service = SynthesisService()
-evaluator_service = EvaluatorService()
+synthesis_service = SynthesisLegacyService()
+evaluator_service = EvaluatorLegacyService()
export_service = Export_Service()
db_manager = DatabaseManager()
diff --git a/app/services/synthesis_legacy_service.py b/app/services/synthesis_legacy_service.py
new file mode 100644
index 00000000..62df587c
--- /dev/null
+++ b/app/services/synthesis_legacy_service.py
@@ -0,0 +1,701 @@
+import boto3
+import json
+import uuid
+import time
+import csv
+from typing import List, Dict, Optional, Tuple
+from datetime import datetime, timezone
+import os
+from huggingface_hub import HfApi, HfFolder, Repository
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+import math
+import asyncio
+from fastapi import FastAPI, BackgroundTasks, HTTPException
+from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError
+from app.core.data_loader import DataLoader
+import pandas as pd
+import numpy as np
+
+from app.models.request_models import SynthesisRequest, Example, ModelParameters
+from app.core.model_handlers import create_handler
+from app.core.prompt_templates import PromptBuilder, PromptHandler
+from app.core.config import UseCase, Technique, get_model_family
+from app.services.aws_bedrock import get_bedrock_client
+from app.core.database import DatabaseManager
+from app.services.check_guardrail import ContentGuardrail
+from app.services.doc_extraction import DocumentProcessor
+import logging
+from logging.handlers import RotatingFileHandler
+import traceback
+from app.core.telemetry_integration import track_llm_operation
+
+
+class SynthesisLegacyService:
+ """Legacy service for generating synthetic QA pairs (SFT and Custom_Workflow only)"""
+ QUESTIONS_PER_BATCH = 5 # Maximum questions per batch
+ MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request)
+
+
+ def __init__(self):
+ self.bedrock_client = get_bedrock_client()
+ self.db = DatabaseManager()
+ self._setup_logging()
+ self.guard = ContentGuardrail()
+
+
+ def _setup_logging(self):
+ """Set up logging configuration"""
+ os.makedirs('logs', exist_ok=True)
+
+ self.logger = logging.getLogger('synthesis_legacy_service')
+ self.logger.setLevel(logging.INFO)
+
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+
+ # File handler for general logs
+ file_handler = RotatingFileHandler(
+ 'logs/synthesis_legacy_service.log',
+ maxBytes=10*1024*1024, # 10MB
+ backupCount=5
+ )
+ file_handler.setFormatter(formatter)
+ self.logger.addHandler(file_handler)
+
+ # File handler for errors
+ error_handler = RotatingFileHandler(
+ 'logs/synthesis_legacy_service_errors.log',
+ maxBytes=10*1024*1024,
+ backupCount=5
+ )
+ error_handler.setLevel(logging.ERROR)
+ error_handler.setFormatter(formatter)
+ self.logger.addHandler(error_handler)
+
+
+ #@track_llm_operation("process_single_topic")
+ def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]:
+ """
+ Process a single topic to generate questions and solutions.
+ Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails.
+
+ Args:
+ topic: The topic to generate questions for
+ model_handler: Handler for the AI model
+ request: The synthesis request object
+ num_questions: Total number of questions to generate
+
+ Returns:
+ Tuple containing:
+ - topic (str)
+ - list of validated QA pairs
+ - list of error messages
+ - list of output dictionaries with topic information
+
+ Raises:
+ ModelHandlerError: When there's an error in model generation that should stop processing
+ """
+ topic_results = []
+ topic_output = []
+ topic_errors = []
+ questions_remaining = num_questions
+ omit_questions = []
+
+ try:
+ # Process questions in batches
+ for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH):
+ if questions_remaining <= 0:
+ break
+
+ batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining)
+ self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}")
+
+ try:
+ # Attempt batch processing
+ prompt = PromptBuilder.build_prompt(
+ model_id=request.model_id,
+ use_case=request.use_case,
+ topic=topic,
+ num_questions=batch_size,
+ omit_questions=omit_questions,
+ examples=request.examples or [],
+ technique=request.technique,
+ schema=request.schema,
+ custom_prompt=request.custom_prompt,
+ )
+ # print("prompt :", prompt)
+ batch_qa_pairs = None
+ try:
+ batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id)
+ except ModelHandlerError as e:
+ self.logger.warning(f"Batch processing failed: {str(e)}")
+ if isinstance(e, JSONParsingError):
+ # For JSON parsing errors, fall back to single processing
+ self.logger.info("JSON parsing failed, falling back to single processing")
+ continue
+ else:
+ # For other model errors, propagate up
+ raise
+
+ if batch_qa_pairs:
+ # Process batch results
+ valid_pairs = []
+ valid_outputs = []
+ invalid_count = 0
+
+ for pair in batch_qa_pairs:
+ if self._validate_qa_pair(pair):
+ valid_pairs.append({
+ "question": pair["question"],
+ "solution": pair["solution"]
+ })
+ valid_outputs.append({
+ "Topic": topic,
+ "question": pair["question"],
+ "solution": pair["solution"]
+ })
+ omit_questions.append(pair["question"])
+ #else:
+ invalid_count = batch_size - len(valid_pairs)
+
+ if valid_pairs:
+ topic_results.extend(valid_pairs)
+ topic_output.extend(valid_outputs)
+ questions_remaining -= len(valid_pairs)
+ omit_questions = omit_questions[-100:] # Keep last 100 questions
+ self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}")
+ print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
+ # If all pairs were valid, skip fallback
+ if invalid_count <= 0:
+ continue
+
+ else:
+ # Fall back to single processing for remaining or failed questions
+ self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}")
+ remaining_batch = invalid_count
+ print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
+ for _ in range(remaining_batch):
+ if questions_remaining <= 0:
+ break
+
+ try:
+ # Single question processing
+ prompt = PromptBuilder.build_prompt(
+ model_id=request.model_id,
+ use_case=request.use_case,
+ topic=topic,
+ num_questions=1,
+ omit_questions=omit_questions,
+ examples=request.examples or [],
+ technique=request.technique,
+ schema=request.schema,
+ custom_prompt=request.custom_prompt,
+ )
+
+ try:
+ single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id)
+ except ModelHandlerError as e:
+ self.logger.warning(f"Batch processing failed: {str(e)}")
+ if isinstance(e, JSONParsingError):
+ # For JSON parsing errors, fall back to single processing
+ self.logger.info("JSON parsing failed, falling back to single processing")
+ continue
+ else:
+ # For other model errors, propagate up
+ raise
+
+ if single_qa_pairs and len(single_qa_pairs) > 0:
+ pair = single_qa_pairs[0]
+ if self._validate_qa_pair(pair):
+ validated_pair = {
+ "question": pair["question"],
+ "solution": pair["solution"]
+ }
+ validated_output = {
+ "Topic": topic,
+ "question": pair["question"],
+ "solution": pair["solution"]
+ }
+
+ topic_results.append(validated_pair)
+ topic_output.append(validated_output)
+ omit_questions.append(pair["question"])
+ omit_questions = omit_questions[-100:]
+ questions_remaining -= 1
+
+ self.logger.info(f"Successfully generated single question for topic {topic}")
+ else:
+ error_msg = f"Invalid QA pair structure in single processing for topic {topic}"
+ self.logger.warning(error_msg)
+ topic_errors.append(error_msg)
+ else:
+ error_msg = f"No QA pair generated in single processing for topic {topic}"
+ self.logger.warning(error_msg)
+ topic_errors.append(error_msg)
+
+ except ModelHandlerError as e:
+ # Don't raise - add to errors and continue
+ error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}"
+ self.logger.error(error_msg)
+ topic_errors.append(error_msg)
+ continue
+
+ except ModelHandlerError:
+ # Re-raise ModelHandlerError to propagate up
+ raise
+ except Exception as e:
+ error_msg = f"Error processing batch for topic {topic}: {str(e)}"
+ self.logger.error(error_msg)
+ topic_errors.append(error_msg)
+ continue
+
+ except ModelHandlerError:
+ # Re-raise ModelHandlerError to propagate up
+ raise
+ except Exception as e:
+ error_msg = f"Critical error processing topic {topic}: {str(e)}"
+ self.logger.error(error_msg)
+ topic_errors.append(error_msg)
+
+ return topic, topic_results, topic_errors, topic_output
+
+
+ async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict:
+ """Generate examples based on request parameters (SFT technique)"""
+ try:
+ output_key = request.output_key
+ output_value = request.output_value
+ st = time.time()
+ self.logger.info(f"Starting example generation - Demo Mode: {is_demo}")
+
+ # Use default parameters if none provided
+ model_params = request.model_params or ModelParameters()
+
+ # Create model handler
+ self.logger.info("Creating model handler")
+ model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint)
+
+ # Limit topics and questions in demo mode
+ if request.doc_paths:
+ processor = DocumentProcessor(chunk_size=1000, overlap=100)
+ paths = request.doc_paths
+ topics = []
+ for path in paths:
+ chunks = processor.process_document(path)
+ topics.extend(chunks)
+ #topics = topics[0:5]
+ print("total chunks: ", len(topics))
+ if request.num_questions<=len(topics):
+ topics = topics[0:request.num_questions]
+ num_questions = 1
+ print("num_questions :", num_questions)
+ else:
+ num_questions = math.ceil(request.num_questions/len(topics))
+ #print(num_questions)
+ total_count = request.num_questions
+ else:
+ if request.topics:
+ topics = request.topics
+ num_questions = request.num_questions
+ total_count = request.num_questions*len(request.topics)
+
+ else:
+ self.logger.error("Generation failed: No topics provided")
+ raise RuntimeError("Invalid input: No topics provided")
+
+
+ # Track results for each topic
+ results = {}
+ all_errors = []
+ final_output = []
+
+ # Create thread pool
+ loop = asyncio.get_event_loop()
+ max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ topic_futures = [
+ loop.run_in_executor(
+ executor,
+ self.process_single_topic,
+ topic,
+ model_handler,
+ request,
+ num_questions,
+ request_id
+ )
+ for topic in topics
+ ]
+
+ # Wait for all futures to complete
+ try:
+ completed_topics = await asyncio.gather(*topic_futures)
+ except ModelHandlerError as e:
+ self.logger.error(f"Model generation failed: {str(e)}")
+ raise APIError(f"Failed to generate content: {str(e)}")
+
+ # Process results
+
+ for topic, topic_results, topic_errors, topic_output in completed_topics:
+ if topic_errors:
+ all_errors.extend(topic_errors)
+ if topic_results and is_demo:
+ results[topic] = topic_results
+ if topic_output:
+ final_output.extend(topic_output)
+
+ generation_time = time.time() - st
+ self.logger.info(f"Generation completed in {generation_time:.2f} seconds")
+
+ timestamp = datetime.now(timezone.utc).isoformat()
+ time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
+ mode_suffix = "test" if is_demo else "final"
+ model_name = get_model_family(request.model_id).split('.')[-1]
+ file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json"
+ if request.doc_paths:
+ final_output = [{
+ 'Generated_From': item['Topic'],
+ output_key: item['question'],
+ output_value: item['solution'] }
+ for item in final_output]
+ else:
+ final_output = [{
+ 'Seeds': item['Topic'],
+ output_key: item['question'],
+ output_value: item['solution'] }
+ for item in final_output]
+ output_path = {}
+ try:
+ with open(file_path, "w") as f:
+ json.dump(final_output, indent=2, fp=f)
+ except Exception as e:
+ self.logger.error(f"Error saving results: {str(e)}", exc_info=True)
+
+ output_path['local']= file_path
+
+
+
+
+ # Handle custom prompt, examples and schema
+ custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt)
+ # For examples
+ examples_value = (
+ PromptHandler.get_default_example(request.use_case, request.examples)
+ if hasattr(request, 'examples')
+ else None
+ )
+ examples_str = self.safe_json_dumps(examples_value)
+
+ # For schema
+ schema_value = (
+ PromptHandler.get_default_schema(request.use_case, request.schema)
+ if hasattr(request, 'schema')
+ else None
+ )
+ schema_str = self.safe_json_dumps(schema_value)
+
+ # For topics
+ topics_value = topics if not getattr(request, 'doc_paths', None) else None
+ topic_str = self.safe_json_dumps(topics_value)
+
+ # For doc_paths and input_path
+ doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None))
+ input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None))
+
+ metadata = {
+ 'timestamp': timestamp,
+ 'technique': request.technique,
+ 'model_id': request.model_id,
+ 'inference_type': request.inference_type,
+ 'caii_endpoint':request.caii_endpoint,
+ 'use_case': request.use_case,
+ 'final_prompt': custom_prompt_str,
+ 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
+ 'generate_file_name': os.path.basename(output_path['local']),
+ 'display_name': request.display_name,
+ 'output_path': output_path,
+ 'num_questions':getattr(request, 'num_questions', None),
+ 'topics': topic_str,
+ 'examples': examples_str,
+ "total_count":total_count,
+ 'schema': schema_str,
+ 'doc_paths': doc_paths_str,
+ 'input_path':input_path_str,
+ 'input_key': request.input_key,
+ 'output_key':request.output_key,
+ 'output_value':request.output_value
+ }
+
+ #print("metadata: ",metadata)
+ if is_demo:
+
+ self.db.save_generation_metadata(metadata)
+ return {
+ "status": "completed" if results else "failed",
+ "results": results,
+ "errors": all_errors if all_errors else None,
+ "export_path": output_path
+ }
+ else:
+ # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1])
+ # time_stamp = extract_timestamp(metadata.get('generate_file_name'))
+ job_status = "ENGINE_SUCCEEDED"
+ generate_file_name = os.path.basename(output_path['local'])
+
+ self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status)
+ self.db.backup_and_restore_db()
+ return {
+ "status": "completed" if final_output else "failed",
+ "export_path": output_path
+ }
+ except APIError:
+ raise
+
+ except Exception as e:
+ self.logger.error(f"Generation failed: {str(e)}", exc_info=True)
+ if is_demo:
+ raise APIError(str(e)) # Let middleware decide status code
+ else:
+ time_stamp = datetime.now(timezone.utc).isoformat()
+ job_status = "ENGINE_FAILED"
+ file_name = ''
+ output_path = ''
+ self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status)
+ raise # Just re-raise the original exception
+
+
+ def _validate_qa_pair(self, pair: Dict) -> bool:
+ """Validate a question-answer pair"""
+ return (
+ isinstance(pair, dict) and
+ "question" in pair and
+ "solution" in pair and
+ isinstance(pair["question"], str) and
+ isinstance(pair["solution"], str) and
+ len(pair["question"].strip()) > 0 and
+ len(pair["solution"].strip()) > 0
+ )
+
+ #@track_llm_operation("process_single_input")
+ async def process_single_input(self, input, model_handler, request, request_id=None):
+ try:
+ prompt = PromptBuilder.build_generate_result_prompt(
+ model_id=request.model_id,
+ use_case=request.use_case,
+ input=input,
+ examples=request.examples or [],
+ schema=request.schema,
+ custom_prompt=request.custom_prompt,
+ )
+ try:
+ result = model_handler.generate_response(prompt, request_id=request_id)
+ except ModelHandlerError as e:
+ self.logger.error(f"ModelHandlerError in generate_response: {str(e)}")
+ raise
+
+ return {"question": input, "solution": result}
+
+ except ModelHandlerError:
+ raise
+ except Exception as e:
+ self.logger.error(f"Error processing input: {str(e)}")
+ raise APIError(f"Failed to process input: {str(e)}")
+
+ async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict:
+ """Generate results based on request parameters (Custom_Workflow technique)"""
+ try:
+ self.logger.info(f"Starting example generation - Demo Mode: {is_demo}")
+
+
+ # Use default parameters if none provided
+ model_params = request.model_params or ModelParameters()
+
+ # Create model handler
+ self.logger.info("Creating model handler")
+ model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True)
+
+ inputs = []
+ file_paths = request.input_path
+ for path in file_paths:
+ try:
+ with open(path) as f:
+ data = json.load(f)
+ inputs.extend(item.get(request.input_key, '') for item in data)
+ except Exception as e:
+ print(f"Error processing {path}: {str(e)}")
+ MAX_WORKERS = 5
+
+
+ # Create thread pool
+ loop = asyncio.get_event_loop()
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+ # Create futures for each input
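+ # Each worker thread runs its own event loop via asyncio.run, so the async
+ # process_single_input coroutine executes off the main event loop in parallel.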
+ input_futures = [
+ loop.run_in_executor(
+ executor,
+ lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)),
+ input
+ )
+ for input in inputs
+ ]
+
+ # Wait for all futures to complete
+ try:
+ final_output = await asyncio.gather(*input_futures)
+ except ModelHandlerError as e:
+ self.logger.error(f"Model generation failed: {str(e)}")
+ raise APIError(f"Failed to generate content: {str(e)}")
+
+ timestamp = datetime.now(timezone.utc).isoformat()
+ time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
+ mode_suffix = "test" if is_demo else "final"
+ model_name = get_model_family(request.model_id).split('.')[-1]
+ file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json"
+ input_key = request.output_key or request.input_key
+ result = [{
+ input_key: item['question'],
+ request.output_value: item['solution']}
+ for item in final_output]
+ output_path = {}
+ try:
+ with open(file_path, "w") as f:
+ json.dump(result, indent=2, fp=f)
+ except Exception as e:
+ self.logger.error(f"Error saving results: {str(e)}", exc_info=True)
+
+ output_path['local'] = file_path
+
+ # Handle custom prompt, examples and schema
+ custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt)
+ # For examples
+ examples_value = (
+ PromptHandler.get_default_example(request.use_case, request.examples)
+ if hasattr(request, 'examples')
+ else None
+ )
+ examples_str = self.safe_json_dumps(examples_value)
+
+ # For schema
+ schema_value = (
+ PromptHandler.get_default_schema(request.use_case, request.schema)
+ if hasattr(request, 'schema')
+ else None
+ )
+ schema_str = self.safe_json_dumps(schema_value)
+
+ # For topics
+ topics_value = None
+ topic_str = self.safe_json_dumps(topics_value)
+
+ # For doc_paths and input_path
+ doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None))
+ input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None))
+
+
+
+ metadata = {
+ 'timestamp': timestamp,
+ 'technique': request.technique,
+ 'model_id': request.model_id,
+ 'inference_type': request.inference_type,
+ 'caii_endpoint': request.caii_endpoint,
+ 'use_case': request.use_case,
+ 'final_prompt': custom_prompt_str,
+ 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
+ 'generate_file_name': os.path.basename(output_path['local']),
+ 'display_name': request.display_name,
+ 'output_path': output_path,
+ 'num_questions': getattr(request, 'num_questions', None),
+ 'topics': topic_str,
+ 'examples': examples_str,
+ 'total_count': len(inputs),
+ 'schema': schema_str,
+ 'doc_paths': doc_paths_str,
+ 'input_path': input_path_str,
+ 'input_key': request.input_key,
+ 'output_key': request.output_key,
+ 'output_value': request.output_value
+ }
+
+
+ if is_demo:
+
+ self.db.save_generation_metadata(metadata)
+ return {
+ "status": "completed" if final_output else "failed",
+ "results": final_output,
+ "export_path": output_path
+ }
+ else:
+ # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1])
+ # time_stamp = extract_timestamp(metadata.get('generate_file_name'))
+ job_status = "success"
+ generate_file_name = os.path.basename(output_path['local'])
+
+ self.db.update_job_generate(job_name, generate_file_name, output_path['local'], timestamp, job_status)
+ self.db.backup_and_restore_db()
+ return {
+ "status": "completed" if final_output else "failed",
+ "export_path": output_path
+ }
+
+ except APIError:
+ raise
+ except Exception as e:
+ self.logger.error(f"Generation failed: {str(e)}", exc_info=True)
+ if is_demo:
+ raise APIError(str(e)) # Let middleware decide status code
+ else:
+ time_stamp = datetime.now(timezone.utc).isoformat()
+ job_status = "failure"
+ file_name = ''
+ output_path = ''
+ self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status)
+ raise # Just re-raise the original exception
+
+ def get_health_check(self) -> Dict:
+ """Get service health status"""
+ try:
+ test_body = {
+ "prompt": "\n\nHuman: test\n\nAssistant: ",
+ "max_tokens_to_sample": 1,
+ "temperature": 0
+ }
+
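+ # Minimal one-token invocation used only to confirm Bedrock connectivity and credentials.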
+ self.bedrock_client.invoke_model(
+ modelId="anthropic.claude-instant-v1",
+ body=json.dumps(test_body)
+ )
+
+ status = {
+ "status": "healthy",
+ "timestamp": datetime.now().isoformat(),
+ "service": "SynthesisLegacyService",
+ "aws_region": self.bedrock_client.meta.region_name
+ }
+ self.logger.info("Health check passed", extra=status)
+ return status
+
+ except Exception as e:
+ status = {
+ "status": "unhealthy",
+ "error": str(e),
+ "timestamp": datetime.now().isoformat(),
+ "service": "SynthesisLegacyService",
+ "aws_region": self.bedrock_client.meta.region_name
+ }
+ self.logger.error("Health check failed", extra=status, exc_info=True)
+ return status
+
+ def safe_json_dumps(self, value):
+ """Convert value to JSON string only if it's not None"""
+ return json.dumps(value) if value is not None else None
diff --git a/app/services/synthesis_service.py b/app/services/synthesis_service.py
index e42ac3c8..6f8e643e 100644
--- a/app/services/synthesis_service.py
+++ b/app/services/synthesis_service.py
@@ -34,19 +34,13 @@
class SynthesisService:
- """Service for generating synthetic QA pairs"""
+ """Service for generating synthetic freeform data (Freeform technique only)"""
QUESTIONS_PER_BATCH = 5 # Maximum questions per batch
- MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations
+ MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request)
def __init__(self):
- # self.bedrock_client = boto3.Session(profile_name='cu_manowar_dev').client(
- # 'bedrock-runtime',
- # region_name='us-west-2'
- # )
- self.bedrock_client = get_bedrock_client()
-
-
+ self.bedrock_client = get_bedrock_client()
self.db = DatabaseManager()
self._setup_logging()
self.guard = ContentGuardrail()
@@ -80,592 +74,6 @@ def _setup_logging(self):
error_handler.setFormatter(formatter)
self.logger.addHandler(error_handler)
-
- #@track_llm_operation("process_single_topic")
- def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]:
- """
- Process a single topic to generate questions and solutions.
- Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails.
-
- Args:
- topic: The topic to generate questions for
- model_handler: Handler for the AI model
- request: The synthesis request object
- num_questions: Total number of questions to generate
-
- Returns:
- Tuple containing:
- - topic (str)
- - list of validated QA pairs
- - list of error messages
- - list of output dictionaries with topic information
-
- Raises:
- ModelHandlerError: When there's an error in model generation that should stop processing
- """
- topic_results = []
- topic_output = []
- topic_errors = []
- questions_remaining = num_questions
- omit_questions = []
-
- try:
- # Process questions in batches
- for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH):
- if questions_remaining <= 0:
- break
-
- batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining)
- self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}")
-
- try:
- # Attempt batch processing
- prompt = PromptBuilder.build_prompt(
- model_id=request.model_id,
- use_case=request.use_case,
- topic=topic,
- num_questions=batch_size,
- omit_questions=omit_questions,
- examples=request.examples or [],
- technique=request.technique,
- schema=request.schema,
- custom_prompt=request.custom_prompt,
- )
- # print("prompt :", prompt)
- batch_qa_pairs = None
- try:
- batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id)
- except ModelHandlerError as e:
- self.logger.warning(f"Batch processing failed: {str(e)}")
- if isinstance(e, JSONParsingError):
- # For JSON parsing errors, fall back to single processing
- self.logger.info("JSON parsing failed, falling back to single processing")
- continue
- else:
- # For other model errors, propagate up
- raise
-
- if batch_qa_pairs:
- # Process batch results
- valid_pairs = []
- valid_outputs = []
- invalid_count = 0
-
- for pair in batch_qa_pairs:
- if self._validate_qa_pair(pair):
- valid_pairs.append({
- "question": pair["question"],
- "solution": pair["solution"]
- })
- valid_outputs.append({
- "Topic": topic,
- "question": pair["question"],
- "solution": pair["solution"]
- })
- omit_questions.append(pair["question"])
- #else:
- invalid_count = batch_size - len(valid_pairs)
-
- if valid_pairs:
- topic_results.extend(valid_pairs)
- topic_output.extend(valid_outputs)
- questions_remaining -= len(valid_pairs)
- omit_questions = omit_questions[-100:] # Keep last 100 questions
- self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}")
- print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
- # If all pairs were valid, skip fallback
- if invalid_count <= 0:
- continue
-
- else:
- # Fall back to single processing for remaining or failed questions
- self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}")
- remaining_batch = invalid_count
- print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs))
- for _ in range(remaining_batch):
- if questions_remaining <= 0:
- break
-
- try:
- # Single question processing
- prompt = PromptBuilder.build_prompt(
- model_id=request.model_id,
- use_case=request.use_case,
- topic=topic,
- num_questions=1,
- omit_questions=omit_questions,
- examples=request.examples or [],
- technique=request.technique,
- schema=request.schema,
- custom_prompt=request.custom_prompt,
- )
-
- try:
- single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id)
- except ModelHandlerError as e:
- self.logger.warning(f"Batch processing failed: {str(e)}")
- if isinstance(e, JSONParsingError):
- # For JSON parsing errors, fall back to single processing
- self.logger.info("JSON parsing failed, falling back to single processing")
- continue
- else:
- # For other model errors, propagate up
- raise
-
- if single_qa_pairs and len(single_qa_pairs) > 0:
- pair = single_qa_pairs[0]
- if self._validate_qa_pair(pair):
- validated_pair = {
- "question": pair["question"],
- "solution": pair["solution"]
- }
- validated_output = {
- "Topic": topic,
- "question": pair["question"],
- "solution": pair["solution"]
- }
-
- topic_results.append(validated_pair)
- topic_output.append(validated_output)
- omit_questions.append(pair["question"])
- omit_questions = omit_questions[-100:]
- questions_remaining -= 1
-
- self.logger.info(f"Successfully generated single question for topic {topic}")
- else:
- error_msg = f"Invalid QA pair structure in single processing for topic {topic}"
- self.logger.warning(error_msg)
- topic_errors.append(error_msg)
- else:
- error_msg = f"No QA pair generated in single processing for topic {topic}"
- self.logger.warning(error_msg)
- topic_errors.append(error_msg)
-
- except ModelHandlerError as e:
- # Don't raise - add to errors and continue
- error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}"
- self.logger.error(error_msg)
- topic_errors.append(error_msg)
- continue
-
- except ModelHandlerError:
- # Re-raise ModelHandlerError to propagate up
- raise
- except Exception as e:
- error_msg = f"Error processing batch for topic {topic}: {str(e)}"
- self.logger.error(error_msg)
- topic_errors.append(error_msg)
- continue
-
- except ModelHandlerError:
- # Re-raise ModelHandlerError to propagate up
- raise
- except Exception as e:
- error_msg = f"Critical error processing topic {topic}: {str(e)}"
- self.logger.error(error_msg)
- topic_errors.append(error_msg)
-
- return topic, topic_results, topic_errors, topic_output
-
-
- async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict:
- """Generate examples based on request parameters"""
- try:
- output_key = request.output_key
- output_value = request.output_value
- st = time.time()
- self.logger.info(f"Starting example generation - Demo Mode: {is_demo}")
-
- # Use default parameters if none provided
- model_params = request.model_params or ModelParameters()
-
- # Create model handler
- self.logger.info("Creating model handler")
- model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint)
-
- # Limit topics and questions in demo mode
- if request.doc_paths:
- processor = DocumentProcessor(chunk_size=1000, overlap=100)
- paths = request.doc_paths
- topics = []
- for path in paths:
- chunks = processor.process_document(path)
- topics.extend(chunks)
- #topics = topics[0:5]
- print("total chunks: ", len(topics))
- if request.num_questions<=len(topics):
- topics = topics[0:request.num_questions]
- num_questions = 1
- print("num_questions :", num_questions)
- else:
- num_questions = math.ceil(request.num_questions/len(topics))
- #print(num_questions)
- total_count = request.num_questions
- else:
- if request.topics:
- topics = request.topics
- num_questions = request.num_questions
- total_count = request.num_questions*len(request.topics)
-
- else:
- self.logger.error("Generation failed: No topics provided")
- raise RuntimeError("Invalid input: No topics provided")
-
-
- # Track results for each topic
- results = {}
- all_errors = []
- final_output = []
-
- # Create thread pool
- loop = asyncio.get_event_loop()
- with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor:
- topic_futures = [
- loop.run_in_executor(
- executor,
- self.process_single_topic,
- topic,
- model_handler,
- request,
- num_questions,
- request_id
- )
- for topic in topics
- ]
-
- # Wait for all futures to complete
- try:
- completed_topics = await asyncio.gather(*topic_futures)
- except ModelHandlerError as e:
- self.logger.error(f"Model generation failed: {str(e)}")
- raise APIError(f"Failed to generate content: {str(e)}")
-
- # Process results
-
- for topic, topic_results, topic_errors, topic_output in completed_topics:
- if topic_errors:
- all_errors.extend(topic_errors)
- if topic_results and is_demo:
- results[topic] = topic_results
- if topic_output:
- final_output.extend(topic_output)
-
- generation_time = time.time() - st
- self.logger.info(f"Generation completed in {generation_time:.2f} seconds")
-
- timestamp = datetime.now(timezone.utc).isoformat()
- time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
- mode_suffix = "test" if is_demo else "final"
- model_name = get_model_family(request.model_id).split('.')[-1]
- file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json"
- if request.doc_paths:
- final_output = [{
- 'Generated_From': item['Topic'],
- output_key: item['question'],
- output_value: item['solution'] }
- for item in final_output]
- else:
- final_output = [{
- 'Seeds': item['Topic'],
- output_key: item['question'],
- output_value: item['solution'] }
- for item in final_output]
- output_path = {}
- try:
- with open(file_path, "w") as f:
- json.dump(final_output, indent=2, fp=f)
- except Exception as e:
- self.logger.error(f"Error saving results: {str(e)}", exc_info=True)
-
- output_path['local']= file_path
-
-
-
-
- # Handle custom prompt, examples and schema
- custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt)
- # For examples
- examples_value = (
- PromptHandler.get_default_example(request.use_case, request.examples)
- if hasattr(request, 'examples')
- else None
- )
- examples_str = self.safe_json_dumps(examples_value)
-
- # For schema
- schema_value = (
- PromptHandler.get_default_schema(request.use_case, request.schema)
- if hasattr(request, 'schema')
- else None
- )
- schema_str = self.safe_json_dumps(schema_value)
-
- # For topics
- topics_value = topics if not getattr(request, 'doc_paths', None) else None
- topic_str = self.safe_json_dumps(topics_value)
-
- # For doc_paths and input_path
- doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None))
- input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None))
-
- metadata = {
- 'timestamp': timestamp,
- 'technique': request.technique,
- 'model_id': request.model_id,
- 'inference_type': request.inference_type,
- 'caii_endpoint':request.caii_endpoint,
- 'use_case': request.use_case,
- 'final_prompt': custom_prompt_str,
- 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
- 'generate_file_name': os.path.basename(output_path['local']),
- 'display_name': request.display_name,
- 'output_path': output_path,
- 'num_questions':getattr(request, 'num_questions', None),
- 'topics': topic_str,
- 'examples': examples_str,
- "total_count":total_count,
- 'schema': schema_str,
- 'doc_paths': doc_paths_str,
- 'input_path':input_path_str,
- 'input_key': request.input_key,
- 'output_key':request.output_key,
- 'output_value':request.output_value
- }
-
- #print("metadata: ",metadata)
- if is_demo:
-
- self.db.save_generation_metadata(metadata)
- return {
- "status": "completed" if results else "failed",
- "results": results,
- "errors": all_errors if all_errors else None,
- "export_path": output_path
- }
- else:
- # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1])
- # time_stamp = extract_timestamp(metadata.get('generate_file_name'))
- job_status = "ENGINE_SUCCEEDED"
- generate_file_name = os.path.basename(output_path['local'])
-
- self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status)
- self.db.backup_and_restore_db()
- return {
- "status": "completed" if final_output else "failed",
- "export_path": output_path
- }
- except APIError:
- raise
-
- except Exception as e:
- self.logger.error(f"Generation failed: {str(e)}", exc_info=True)
- if is_demo:
- raise APIError(str(e)) # Let middleware decide status code
- else:
- time_stamp = datetime.now(timezone.utc).isoformat()
- job_status = "ENGINE_FAILED"
- file_name = ''
- output_path = ''
- self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status)
- raise # Just re-raise the original exception
-
-
- def _validate_qa_pair(self, pair: Dict) -> bool:
- """Validate a question-answer pair"""
- return (
- isinstance(pair, dict) and
- "question" in pair and
- "solution" in pair and
- isinstance(pair["question"], str) and
- isinstance(pair["solution"], str) and
- len(pair["question"].strip()) > 0 and
- len(pair["solution"].strip()) > 0
- )
- #@track_llm_operation("process_single_input")
- async def process_single_input(self, input, model_handler, request, request_id=None):
- try:
- prompt = PromptBuilder.build_generate_result_prompt(
- model_id=request.model_id,
- use_case=request.use_case,
- input=input,
- examples=request.examples or [],
- schema=request.schema,
- custom_prompt=request.custom_prompt,
- )
- try:
- result = model_handler.generate_response(prompt, request_id=request_id)
- except ModelHandlerError as e:
- self.logger.error(f"ModelHandlerError in generate_response: {str(e)}")
- raise
-
- return {"question": input, "solution": result}
-
- except ModelHandlerError:
- raise
- except Exception as e:
- self.logger.error(f"Error processing input: {str(e)}")
- raise APIError(f"Failed to process input: {str(e)}")
-
- async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict:
- try:
- self.logger.info(f"Starting example generation - Demo Mode: {is_demo}")
-
-
- # Use default parameters if none provided
- model_params = request.model_params or ModelParameters()
-
- # Create model handler
- self.logger.info("Creating model handler")
- model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True)
-
- inputs = []
- file_paths = request.input_path
- for path in file_paths:
- try:
- with open(path) as f:
- data = json.load(f)
- inputs.extend(item.get(request.input_key, '') for item in data)
- except Exception as e:
- print(f"Error processing {path}: {str(e)}")
- MAX_WORKERS = 5
-
-
- # Create thread pool
- loop = asyncio.get_event_loop()
- with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
- # Create futures for each input
- input_futures = [
- loop.run_in_executor(
- executor,
- lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)),
- input
- )
- for input in inputs
- ]
-
- # Wait for all futures to complete
- try:
- final_output = await asyncio.gather(*input_futures)
- except ModelHandlerError as e:
- self.logger.error(f"Model generation failed: {str(e)}")
- raise APIError(f"Failed to generate content: {str(e)}")
-
-
-
-
- timestamp = datetime.now(timezone.utc).isoformat()
- time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3]
- mode_suffix = "test" if is_demo else "final"
- model_name = get_model_family(request.model_id).split('.')[-1]
- file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json"
- input_key = request.output_key or request.input_key
- result = [{
-
- input_key: item['question'],
- request.output_value: item['solution'] }
- for item in final_output]
- output_path = {}
- try:
- with open(file_path, "w") as f:
- json.dump(result, indent=2, fp=f)
- except Exception as e:
- self.logger.error(f"Error saving results: {str(e)}", exc_info=True)
-
-
-
-
-
- output_path['local']= file_path
-
-
- # Handle custom prompt, examples and schema
- custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt)
- # For examples
- examples_value = (
- PromptHandler.get_default_example(request.use_case, request.examples)
- if hasattr(request, 'examples')
- else None
- )
- examples_str = self.safe_json_dumps(examples_value)
-
- # For schema
- schema_value = (
- PromptHandler.get_default_schema(request.use_case, request.schema)
- if hasattr(request, 'schema')
- else None
- )
- schema_str = self.safe_json_dumps(schema_value)
-
- # For topics
- topics_value = None
- topic_str = self.safe_json_dumps(topics_value)
-
- # For doc_paths and input_path
- doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None))
- input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None))
-
-
-
- metadata = {
- 'timestamp': timestamp,
- 'technique': request.technique,
- 'model_id': request.model_id,
- 'inference_type': request.inference_type,
- 'caii_endpoint':request.caii_endpoint,
- 'use_case': request.use_case,
- 'final_prompt': custom_prompt_str,
- 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None,
- 'generate_file_name': os.path.basename(output_path['local']),
- 'display_name': request.display_name,
- 'output_path': output_path,
- 'num_questions':getattr(request, 'num_questions', None),
- 'topics': topic_str,
- 'examples': examples_str,
- "total_count":len(inputs),
- 'schema': schema_str,
- 'doc_paths': doc_paths_str,
- 'input_path':input_path_str,
- 'input_key': request.input_key,
- 'output_key':request.output_key,
- 'output_value':request.output_value
- }
-
-
- if is_demo:
-
- self.db.save_generation_metadata(metadata)
- return {
- "status": "completed" if final_output else "failed",
- "results": final_output,
- "export_path": output_path
- }
- else:
- # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1])
- # time_stamp = extract_timestamp(metadata.get('generate_file_name'))
- job_status = "success"
- generate_file_name = os.path.basename(output_path['local'])
-
- self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status)
- self.db.backup_and_restore_db()
- return {
- "status": "completed" if final_output else "failed",
- "export_path": output_path
- }
-
- except APIError:
- raise
- except Exception as e:
- self.logger.error(f"Generation failed: {str(e)}", exc_info=True)
- if is_demo:
- raise APIError(str(e)) # Let middleware decide status code
- else:
- time_stamp = datetime.now(timezone.utc).isoformat()
- job_status = "failure"
- file_name = ''
- output_path = ''
- self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status)
- raise # Just re-raise the original exception
-
#@track_llm_operation("process_single_freeform")
def process_single_freeform(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]:
"""
@@ -960,7 +368,8 @@ async def generate_freeform(self, request: SynthesisRequest, job_name=None, is_d
# Create thread pool
loop = asyncio.get_event_loop()
- with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor:
+ max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
topic_futures = [
loop.run_in_executor(
executor,
@@ -1209,4 +618,4 @@ def get_health_check(self) -> Dict:
def safe_json_dumps(self, value):
"""Convert value to JSON string only if it's not None"""
- return json.dumps(value) if value is not None else None
\ No newline at end of file
+ return json.dumps(value) if value is not None else None
diff --git a/tests/integration/test_evaluate_api.py b/tests/integration/test_evaluate_api.py
index a842cd55..2ff664a9 100644
--- a/tests/integration/test_evaluate_api.py
+++ b/tests/integration/test_evaluate_api.py
@@ -3,7 +3,7 @@
import json
from fastapi.testclient import TestClient
from pathlib import Path
-from app.main import app, db_manager, evaluator_service # global instance created on import
+from app.main import app, db_manager, evaluator_legacy_service # global instance created on import
client = TestClient(app)
# Create a dummy bedrock client that simulates the Converse/invoke_model response.
@@ -37,8 +37,8 @@ def mock_qa_file(tmp_path, mock_qa_data):
# Patch the global evaluator_service's AWS client before tests run.
@pytest.fixture(autouse=True)
def patch_evaluator_bedrock_client():
- from app.main import evaluator_service
- evaluator_service.bedrock_client = DummyBedrockClient()
+ from app.main import evaluator_legacy_service
+ evaluator_legacy_service.bedrock_client = DummyBedrockClient()
def test_evaluate_endpoint(mock_qa_file):
request_data = {
@@ -51,7 +51,7 @@ def test_evaluate_endpoint(mock_qa_file):
"output_value": "Completion"
}
# Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation.
- with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}]
response = client.post("/synthesis/evaluate", json=request_data)
# In demo mode, our endpoint returns a dict with "status", "result", and "output_path".
@@ -71,7 +71,7 @@ def test_job_handling(mock_qa_file):
"output_key": "Prompt",
"output_value": "Completion"
}
- with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}]
response = client.post("/synthesis/evaluate", json=request_data)
assert response.status_code == 200
diff --git a/tests/integration/test_evaluate_legacy_api.py b/tests/integration/test_evaluate_legacy_api.py
new file mode 100644
index 00000000..2ff664a9
--- /dev/null
+++ b/tests/integration/test_evaluate_legacy_api.py
@@ -0,0 +1,113 @@
+import pytest
+from unittest.mock import patch, Mock
+import json
+from fastapi.testclient import TestClient
+from pathlib import Path
+from app.main import app, db_manager, evaluator_legacy_service # global instance created on import
+client = TestClient(app)
+
+# Create a dummy bedrock client that simulates the Converse/invoke_model response.
+class DummyBedrockClient:
+ def invoke_model(self, modelId, body):
+ # Return a dummy response structure (adjust if your handler expects a different format)
+ return [{
+ "score": 1.0,
+ "justification": "Dummy response from invoke_model"
+ }]
+ @property
+ def meta(self):
+ class Meta:
+ region_name = "us-west-2"
+ return Meta()
+
+@pytest.fixture
+def mock_qa_data():
+ return [
+ {"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"},
+ {"Seeds": "python_basics", "Prompt": "How do you define a function?", "Completion": "Use the def keyword followed by function name"}
+ ]
+
+@pytest.fixture
+def mock_qa_file(tmp_path, mock_qa_data):
+ file_path = tmp_path / "qa_pairs.json"
+ with open(file_path, "w") as f:
+ json.dump(mock_qa_data, f)
+ return str(file_path)
+
+ # Patch the global evaluator_legacy_service's AWS client before tests run.
+@pytest.fixture(autouse=True)
+def patch_evaluator_bedrock_client():
+ from app.main import evaluator_legacy_service
+ evaluator_legacy_service.bedrock_client = DummyBedrockClient()
+
+def test_evaluate_endpoint(mock_qa_file):
+ request_data = {
+ "use_case": "custom",
+ "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ "inference_type": "aws_bedrock",
+ "import_path": mock_qa_file,
+ "is_demo": True,
+ "output_key": "Prompt",
+ "output_value": "Completion"
+ }
+ # Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation.
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}]
+ response = client.post("/synthesis/evaluate", json=request_data)
+ # In demo mode, our endpoint returns a dict with "status", "result", and "output_path".
+ assert response.status_code == 200
+ res_json = response.json()
+ assert res_json["status"] == "completed"
+ assert "output_path" in res_json
+ assert "result" in res_json
+
+def test_job_handling(mock_qa_file):
+ request_data = {
+ "use_case": "custom",
+ "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ "inference_type": "aws_bedrock",
+ "import_path": mock_qa_file,
+ "is_demo": True,
+ "output_key": "Prompt",
+ "output_value": "Completion"
+ }
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}]
+ response = client.post("/synthesis/evaluate", json=request_data)
+ assert response.status_code == 200
+ res_json = response.json()
+ # In demo mode, we don't expect a "job_id" key; we check for "output_path" and "result".
+ assert "output_path" in res_json
+ # Now simulate history by patching db_manager.get_all_evaluate_metadata
+ db_manager.get_all_evaluate_metadata = lambda: [{"evaluate_file_name": "test.json", "average_score": 0.9}]
+ response = client.get("/evaluations/history")
+ assert response.status_code == 200
+ history = response.json()
+ assert len(history) > 0
+
+def test_evaluate_with_invalid_model(mock_qa_file):
+ request_data = {
+ "use_case": "custom",
+ "model_id": "invalid.model",
+ "inference_type": "aws_bedrock",
+ "import_path": mock_qa_file,
+ "is_demo": True,
+ "output_key": "Prompt",
+ "output_value": "Completion"
+ }
+
+ from app.core.exceptions import ModelHandlerError
+
+ # Patch create_handler to raise ModelHandlerError
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_create:
+ mock_create.side_effect = ModelHandlerError("Invalid model identifier: invalid.model")
+ response = client.post("/synthesis/evaluate", json=request_data)
+
+ # Print for debugging
+ print(f"Response status: {response.status_code}")
+ print(f"Response content: {response.json()}")
+
+ # Expect a 400 or 500 error response
+ assert response.status_code in [400, 500]
+ res_json = response.json()
+ assert "error" in res_json
diff --git a/tests/integration/test_synthesis_api.py b/tests/integration/test_synthesis_api.py
index cbd60e9f..db54d9c7 100644
--- a/tests/integration/test_synthesis_api.py
+++ b/tests/integration/test_synthesis_api.py
@@ -14,7 +14,7 @@ def test_generate_endpoint_with_topics():
"topics": ["python_basics"],
"is_demo": True
}
- with patch('app.main.SynthesisService.generate_examples') as mock_generate:
+ with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate:
mock_generate.return_value = {
"status": "completed",
"export_path": {"local": "test.json"},
@@ -35,7 +35,7 @@ def test_generate_endpoint_with_doc_paths():
"doc_paths": ["test.pdf"],
"is_demo": True
}
- with patch('app.main.SynthesisService.generate_examples') as mock_generate:
+ with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate:
mock_generate.return_value = {
"status": "completed",
"export_path": {"local": "test.json"},
diff --git a/tests/integration/test_synthesis_legacy_api.py b/tests/integration/test_synthesis_legacy_api.py
new file mode 100644
index 00000000..db54d9c7
--- /dev/null
+++ b/tests/integration/test_synthesis_legacy_api.py
@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import patch
+import json
+from fastapi.testclient import TestClient
+from app.main import app, db_manager
+client = TestClient(app)
+
+def test_generate_endpoint_with_topics():
+ request_data = {
+ "use_case": "custom",
+ "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ "inference_type": "aws_bedrock",
+ "num_questions": 2,
+ "topics": ["python_basics"],
+ "is_demo": True
+ }
+ with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate:
+ mock_generate.return_value = {
+ "status": "completed",
+ "export_path": {"local": "test.json"},
+ "results": {"python_basics": [{"question": "test?", "solution": "test!"}]}
+ }
+ response = client.post("/synthesis/generate", json=request_data)
+ assert response.status_code == 200
+ res_json = response.json()
+ assert res_json.get("status") == "completed"
+ assert "export_path" in res_json
+
+def test_generate_endpoint_with_doc_paths():
+ request_data = {
+ "use_case": "custom",
+ "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ "inference_type": "aws_bedrock",
+ "num_questions": 2,
+ "doc_paths": ["test.pdf"],
+ "is_demo": True
+ }
+ with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate:
+ mock_generate.return_value = {
+ "status": "completed",
+ "export_path": {"local": "test.json"},
+ "results": {"chunk1": [{"question": "test?", "solution": "test!"}]}
+ }
+ response = client.post("/synthesis/generate", json=request_data)
+ assert response.status_code == 200
+ res_json = response.json()
+ assert res_json.get("status") == "completed"
+ assert "export_path" in res_json
+
+def test_generation_history():
+ # Patch db_manager.get_paginated_generate_metadata to return dummy metadata with pagination info
+ db_manager.get_paginated_generate_metadata_light = lambda page, page_size: (
+ 1, # total_count
+ [{"generate_file_name": "qa_pairs_claude_20250210T170521148_test.json",
+ "timestamp": "2024-02-10T12:00:00",
+ "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0"}]
+ )
+
+ # Since get_pending_generate_job_ids might be called, we should patch it too
+ db_manager.get_pending_generate_job_ids = lambda: []
+
+ response = client.get("/generations/history?page=1&page_size=10")
+ assert response.status_code == 200
+ res_json = response.json()
+
+ # Check that the response contains the expected structure
+ assert "data" in res_json
+ assert "pagination" in res_json
+
+ # Check pagination metadata
+ assert res_json["pagination"]["total"] == 1
+ assert res_json["pagination"]["page"] == 1
+ assert res_json["pagination"]["page_size"] == 10
+ assert res_json["pagination"]["total_pages"] == 1
+
+ # Check data contents
+ assert len(res_json["data"]) > 0
+ # Instead of expecting exactly "test.json", assert the filename contains "_test.json"
+ assert "_test.json" in res_json["data"][0]["generate_file_name"]
+
+def test_error_handling():
+ request_data = {
+ "use_case": "custom",
+ "model_id": "invalid.model",
+ "is_demo": True
+ }
+ response = client.post("/synthesis/generate", json=request_data)
+ # Expect an error with status code in [400,503] and key "error"
+ assert response.status_code in [400, 503]
+ res_json = response.json()
+ assert "error" in res_json
diff --git a/tests/unit/test_evaluator_freeform_service.py b/tests/unit/test_evaluator_freeform_service.py
new file mode 100644
index 00000000..38c8691d
--- /dev/null
+++ b/tests/unit/test_evaluator_freeform_service.py
@@ -0,0 +1,80 @@
+import pytest
+from unittest.mock import patch, Mock
+import json
+from app.services.evaluator_service import EvaluatorService
+from app.models.request_models import EvaluationRequest
+from tests.mocks.mock_db import MockDatabaseManager
+
+@pytest.fixture
+def mock_freeform_data():
+ return [{"field1": "value1", "field2": "value2", "field3": "value3"}]
+
+@pytest.fixture
+def mock_freeform_file(tmp_path, mock_freeform_data):
+ file_path = tmp_path / "freeform_data.json"
+ with open(file_path, "w") as f:
+ json.dump(mock_freeform_data, f)
+ return str(file_path)
+
+@pytest.fixture
+def evaluator_freeform_service():
+ service = EvaluatorService()
+ service.db = MockDatabaseManager()
+ return service
+
+def test_evaluate_row_data(evaluator_freeform_service, mock_freeform_file):
+ request = EvaluationRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ use_case="custom",
+ import_path=mock_freeform_file,
+ is_demo=True,
+ output_key="field1",
+ output_value="field2"
+ )
+ with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good freeform data"}]
+ result = evaluator_freeform_service.evaluate_row_data(request)
+ assert result["status"] == "completed"
+ assert "output_path" in result
+ assert len(evaluator_freeform_service.db.evaluation_metadata) == 1
+
+def test_evaluate_single_row(evaluator_freeform_service):
+ with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ mock_response = [{"score": 4, "justification": "Good freeform row"}]
+ mock_handler.return_value.generate_response.return_value = mock_response
+
+ row = {"field1": "value1", "field2": "value2"}
+ request = EvaluationRequest(
+ use_case="custom",
+ model_id="test.model",
+ inference_type="aws_bedrock",
+ is_demo=True,
+ output_key="field1",
+ output_value="field2"
+ )
+ result = evaluator_freeform_service.evaluate_single_row(row, mock_handler.return_value, request)
+ assert result["evaluation"]["score"] == 4
+ assert "justification" in result["evaluation"]
+ assert result["row"] == row
+
+def test_evaluate_rows(evaluator_freeform_service):
+ rows = [
+ {"field1": "value1", "field2": "value2"},
+ {"field1": "value3", "field2": "value4"}
+ ]
+ request = EvaluationRequest(
+ use_case="custom",
+ model_id="test.model",
+ inference_type="aws_bedrock",
+ is_demo=True,
+ output_key="field1",
+ output_value="field2"
+ )
+
+ with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good row"}]
+ result = evaluator_freeform_service.evaluate_rows(rows, mock_handler.return_value, request)
+
+ assert result["total_evaluated"] == 2
+ assert result["average_score"] == 4
+ assert len(result["evaluated_rows"]) == 2
diff --git a/tests/unit/test_evaluator_legacy_service.py b/tests/unit/test_evaluator_legacy_service.py
new file mode 100644
index 00000000..c7dded81
--- /dev/null
+++ b/tests/unit/test_evaluator_legacy_service.py
@@ -0,0 +1,83 @@
+import pytest
+from io import StringIO
+from unittest.mock import patch
+import json
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
+from app.models.request_models import EvaluationRequest
+from tests.mocks.mock_db import MockDatabaseManager
+from app.core.exceptions import ModelHandlerError, APIError
+
+@pytest.fixture
+def mock_qa_data():
+ return [{"question": "test question?", "solution": "test solution"}]
+
+@pytest.fixture
+def mock_qa_file(tmp_path, mock_qa_data):
+ file_path = tmp_path / "qa_pairs.json"
+ with open(file_path, "w") as f:
+ json.dump(mock_qa_data, f)
+ return str(file_path)
+
+@pytest.fixture
+def evaluator_service():
+ service = EvaluatorLegacyService()
+ service.db = MockDatabaseManager()
+ return service
+
+def test_evaluate_results(evaluator_service, mock_qa_file):
+ request = EvaluationRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ use_case="custom",
+ import_path=mock_qa_file,
+ is_demo=True,
+ output_key="Prompt",
+ output_value="Completion"
+ )
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}]
+ result = evaluator_service.evaluate_results(request)
+ assert result["status"] == "completed"
+ assert "output_path" in result
+ assert len(evaluator_service.db.evaluation_metadata) == 1
+
+def test_evaluate_single_pair():
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
+ mock_response = [{"score": 4, "justification": "Good explanation"}]
+ mock_handler.return_value.generate_response.return_value = mock_response
+ service = EvaluatorLegacyService()
+ qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"}
+ request = EvaluationRequest(
+ use_case="custom",
+ model_id="test.model",
+ inference_type="aws_bedrock",
+ is_demo=True,
+ output_key="Prompt",
+ output_value="Completion"
+ )
+ result = service.evaluate_single_pair(qa_pair, mock_handler.return_value, request)
+ assert result["evaluation"]["score"] == 4
+ assert "justification" in result["evaluation"]
+
+def test_evaluate_results_with_error():
+ fake_json = '[{"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"}]'
+ class DummyHandler:
+ def generate_response(self, prompt, **kwargs): # Accept any keyword arguments
+ raise ModelHandlerError("Test error")
+ with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \
+ patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \
+ patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \
+ patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"):
+ service = EvaluatorLegacyService()
+ request = EvaluationRequest(
+ use_case="custom",
+ model_id="test.model",
+ inference_type="aws_bedrock",
+ import_path="test.json",
+ is_demo=True,
+ output_key="Prompt",
+ output_value="Completion",
+ caii_endpoint="dummy_endpoint",
+ display_name="dummy"
+ )
+ with pytest.raises(APIError, match="Test error"):
+ service.evaluate_results(request)
diff --git a/tests/unit/test_evaluator_service.py b/tests/unit/test_evaluator_service.py
index 629b80a3..c7dded81 100644
--- a/tests/unit/test_evaluator_service.py
+++ b/tests/unit/test_evaluator_service.py
@@ -2,7 +2,7 @@
from io import StringIO
from unittest.mock import patch
import json
-from app.services.evaluator_service import EvaluatorService
+from app.services.evaluator_legacy_service import EvaluatorLegacyService
from app.models.request_models import EvaluationRequest
from tests.mocks.mock_db import MockDatabaseManager
from app.core.exceptions import ModelHandlerError, APIError
@@ -20,7 +20,7 @@ def mock_qa_file(tmp_path, mock_qa_data):
@pytest.fixture
def evaluator_service():
- service = EvaluatorService()
+ service = EvaluatorLegacyService()
service.db = MockDatabaseManager()
return service
@@ -33,7 +33,7 @@ def test_evaluate_results(evaluator_service, mock_qa_file):
output_key="Prompt",
output_value="Completion"
)
- with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}]
result = evaluator_service.evaluate_results(request)
assert result["status"] == "completed"
@@ -41,10 +41,10 @@ def test_evaluate_results(evaluator_service, mock_qa_file):
assert len(evaluator_service.db.evaluation_metadata) == 1
def test_evaluate_single_pair():
- with patch('app.services.evaluator_service.create_handler') as mock_handler:
+ with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler:
mock_response = [{"score": 4, "justification": "Good explanation"}]
mock_handler.return_value.generate_response.return_value = mock_response
- service = EvaluatorService()
+ service = EvaluatorLegacyService()
qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"}
request = EvaluationRequest(
use_case="custom",
@@ -63,11 +63,11 @@ def test_evaluate_results_with_error():
class DummyHandler:
def generate_response(self, prompt, **kwargs): # Accept any keyword arguments
raise ModelHandlerError("Test error")
- with patch('app.services.evaluator_service.os.path.exists', return_value=True), \
+ with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \
patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \
- patch('app.services.evaluator_service.create_handler', return_value=DummyHandler()), \
- patch('app.services.evaluator_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"):
- service = EvaluatorService()
+ patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \
+ patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"):
+ service = EvaluatorLegacyService()
request = EvaluationRequest(
use_case="custom",
model_id="test.model",
diff --git a/tests/unit/test_synthesis_freeform_service.py b/tests/unit/test_synthesis_freeform_service.py
new file mode 100644
index 00000000..f0bcb861
--- /dev/null
+++ b/tests/unit/test_synthesis_freeform_service.py
@@ -0,0 +1,70 @@
+import pytest
+from unittest.mock import patch, Mock
+import json
+from app.services.synthesis_service import SynthesisService
+from app.models.request_models import SynthesisRequest
+from tests.mocks.mock_db import MockDatabaseManager
+
+@pytest.fixture
+def mock_json_data():
+ return [{"topic": "test_topic", "example_field": "test_value"}]
+
+@pytest.fixture
+def mock_file(tmp_path, mock_json_data):
+ file_path = tmp_path / "test.json"
+ with open(file_path, "w") as f:
+ json.dump(mock_json_data, f)
+ return str(file_path)
+
+@pytest.fixture
+def synthesis_freeform_service():
+ service = SynthesisService()
+ service.db = MockDatabaseManager()
+ return service
+
+@pytest.mark.asyncio
+async def test_generate_freeform_with_topics(synthesis_freeform_service):
+ request = SynthesisRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ num_questions=2,
+ topics=["test_topic"],
+ is_demo=True,
+ use_case="custom",
+ technique="freeform"
+ )
+ with patch('app.services.synthesis_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"field1": "value1", "field2": "value2"}]
+ result = await synthesis_freeform_service.generate_freeform(request)
+ assert result["status"] == "completed"
+ assert len(synthesis_freeform_service.db.generation_metadata) == 1
+ assert synthesis_freeform_service.db.generation_metadata[0]["model_id"] == request.model_id
+
+@pytest.mark.asyncio
+async def test_generate_freeform_with_custom_examples(synthesis_freeform_service):
+ request = SynthesisRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ num_questions=1,
+ topics=["test_topic"],
+ is_demo=True,
+ use_case="custom",
+ technique="freeform",
+ example_custom=[{"example_field": "example_value"}]
+ )
+ with patch('app.services.synthesis_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"generated_field": "generated_value"}]
+ result = await synthesis_freeform_service.generate_freeform(request)
+ assert result["status"] == "completed"
+ assert "export_path" in result
+
+def test_validate_freeform_item(synthesis_freeform_service):
+ # Valid freeform item
+ valid_item = {"field1": "value1", "field2": "value2"}
+ assert synthesis_freeform_service._validate_freeform_item(valid_item) == True
+
+ # Invalid freeform item (empty dict)
+ invalid_item = {}
+ assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False
+
+ # Invalid freeform item (not a dict)
+ invalid_item = "not a dict"
+ assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False
diff --git a/tests/unit/test_synthesis_legacy_service.py b/tests/unit/test_synthesis_legacy_service.py
new file mode 100644
index 00000000..120fad7c
--- /dev/null
+++ b/tests/unit/test_synthesis_legacy_service.py
@@ -0,0 +1,56 @@
+import pytest
+from unittest.mock import patch, Mock
+import json
+from app.services.synthesis_legacy_service import SynthesisLegacyService
+from app.models.request_models import SynthesisRequest
+from tests.mocks.mock_db import MockDatabaseManager
+
+@pytest.fixture
+def mock_json_data():
+ return [{"input": "test question?"}]
+
+@pytest.fixture
+def mock_file(tmp_path, mock_json_data):
+ file_path = tmp_path / "test.json"
+ with open(file_path, "w") as f:
+ json.dump(mock_json_data, f)
+ return str(file_path)
+
+@pytest.fixture
+def synthesis_service():
+ service = SynthesisLegacyService()
+ service.db = MockDatabaseManager()
+ return service
+
+@pytest.mark.asyncio
+async def test_generate_examples_with_topics(synthesis_service):
+ request = SynthesisRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ num_questions=1,
+ topics=["test_topic"],
+ is_demo=True,
+ use_case="custom"
+ )
+ with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler:
+ mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}]
+ result = await synthesis_service.generate_examples(request)
+ assert result["status"] == "completed"
+ assert len(synthesis_service.db.generation_metadata) == 1
+ assert synthesis_service.db.generation_metadata[0]["model_id"] == request.model_id
+
+@pytest.mark.asyncio
+async def test_generate_examples_with_doc_paths(synthesis_service, mock_file):
+ request = SynthesisRequest(
+ model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
+ num_questions=1,
+ doc_paths=[mock_file],
+ is_demo=True,
+ use_case="custom"
+ )
+ with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \
+ patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor:
+ mock_processor.return_value.process_document.return_value = ["chunk1"]
+ mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}]
+ result = await synthesis_service.generate_examples(request)
+ assert result["status"] == "completed"
+ assert len(synthesis_service.db.generation_metadata) == 1
diff --git a/tests/unit/test_synthesis_service.py b/tests/unit/test_synthesis_service.py
index 6e9ca410..120fad7c 100644
--- a/tests/unit/test_synthesis_service.py
+++ b/tests/unit/test_synthesis_service.py
@@ -1,7 +1,7 @@
import pytest
from unittest.mock import patch, Mock
import json
-from app.services.synthesis_service import SynthesisService
+from app.services.synthesis_legacy_service import SynthesisLegacyService
from app.models.request_models import SynthesisRequest
from tests.mocks.mock_db import MockDatabaseManager
@@ -18,7 +18,7 @@ def mock_file(tmp_path, mock_json_data):
@pytest.fixture
def synthesis_service():
- service = SynthesisService()
+ service = SynthesisLegacyService()
service.db = MockDatabaseManager()
return service
@@ -31,7 +31,7 @@ async def test_generate_examples_with_topics(synthesis_service):
is_demo=True,
use_case="custom"
)
- with patch('app.services.synthesis_service.create_handler') as mock_handler:
+ with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler:
mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}]
result = await synthesis_service.generate_examples(request)
assert result["status"] == "completed"
@@ -47,8 +47,8 @@ async def test_generate_examples_with_doc_paths(synthesis_service, mock_file):
is_demo=True,
use_case="custom"
)
- with patch('app.services.synthesis_service.create_handler') as mock_handler, \
- patch('app.services.synthesis_service.DocumentProcessor') as mock_processor:
+ with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \
+ patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor:
mock_processor.return_value.process_document.return_value = ["chunk1"]
mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}]
result = await synthesis_service.generate_examples(request)