diff --git a/backend/app/services/embedding_service/service.py b/backend/app/services/embedding_service/service.py
index 4527c4f9..40ba3c05 100644
--- a/backend/app/services/embedding_service/service.py
+++ b/backend/app/services/embedding_service/service.py
@@ -1,225 +1,329 @@
 import logging
-import config
+import threading
+import asyncio
+import os
 from typing import List, Dict, Any, Optional
+from concurrent.futures import ThreadPoolExecutor
+
 import torch
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.messages import HumanMessage
+
+import config
 from app.core.config import settings
 from app.models.database.weaviate import WeaviateUserProfile
-from app.services.embedding_service.profile_summarization.prompts.summarization_prompt import PROFILE_SUMMARIZATION_PROMPT
+from app.services.embedding_service.profile_summarization.prompts.summarization_prompt import (
+    PROFILE_SUMMARIZATION_PROMPT,
+)
+
+try:
+    import tiktoken
+    TIKTOKEN_AVAILABLE = True
+except ImportError:
+    TIKTOKEN_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
 
 MODEL_NAME = config.MODEL_NAME
-MAX_BATCH_SIZE = config.MAX_BATCH_SIZE
 EMBEDDING_DEVICE = config.EMBEDDING_DEVICE
-
-
-logger = logging.getLogger(__name__)
+MAX_BATCH_SIZE = config.MAX_BATCH_SIZE
+SAFE_BATCH_SIZE = getattr(config, "SAFE_BATCH_SIZE", 32)
+EXECUTOR_MAX_WORKERS = getattr(
+    config, "EXECUTOR_MAX_WORKERS", min(2, os.cpu_count() or 1)
+)
+DEFAULT_MAX_CONCURRENT_GPU_TASKS = getattr(
+    config, "MAX_CONCURRENT_GPU_TASKS", 2
+)
 
 
 class ProfileSummaryResult(BaseModel):
-    """Result of profile summarization"""
+    """LLM-generated profile summary with its embedding and token estimate."""
+
     summary_text: str
     token_count_estimate: int
-    embedding: Optional[List[float]] = None
+    embedding: List[float]
 
 
-class EmbeddingService:
-    """Service for generating embeddings and profile summarization for Weaviate integration"""
-    def __init__(self, model_name: str = MODEL_NAME, device: str = EMBEDDING_DEVICE):
-        """Initialize the embedding service with specified model and LLM"""
-        self.model_name = model_name
-        self.device = device
-        self._model = None
-        self._llm = None
-        logger.info(f"Initializing EmbeddingService with model: {model_name} on device: {device}")
+class EmbeddingService:
+    """Service for generating embeddings and profile summaries for Weaviate."""
+
+    # Class-level resources shared across all instances
+    _global_model: Optional[SentenceTransformer] = None
+    _global_model_lock = threading.Lock()
+    _shutting_down_global = False  # Global shutdown flag for shared model
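+    # Only the SentenceTransformer weights are shared; executors, the LLM
+    # client, the tokenizer, and semaphores below are per-instance state.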
 
-    @property
-    def model(self) -> SentenceTransformer:
-        """Lazy-load embedding model to avoid loading during import"""
-        if self._model is None:
-            try:
-                logger.info(f"Loading embedding model: {self.model_name}")
-                self._model = SentenceTransformer(self.model_name, device=self.device)
-                logger.info(
-                    f"Model loaded successfully. Embedding dimension: {self._model.get_sentence_embedding_dimension()}")
-            except Exception as e:
-                logger.error(f"Error loading model {self.model_name}: {str(e)}")
-                raise
-        return self._model
+    def __init__(self) -> None:
+        self._llm: Optional[ChatGoogleGenerativeAI] = None
+        self._tokenizer: Optional[Any] = None
 
-    @property
-    def llm(self) -> ChatGoogleGenerativeAI:
-        """Lazy-load LLM for profile summarization"""
-        if self._llm is None:
-            try:
-                self._llm = ChatGoogleGenerativeAI(
-                    model=settings.github_agent_model,
-                    temperature=0.3,
-                    google_api_key=settings.gemini_api_key
-                )
-                logger.info("LLM initialized for profile summarization")
-            except Exception as e:
-                logger.error(f"Error initializing LLM: {str(e)}")
-                raise
-        return self._llm
+        self._embedding_executor: Optional[ThreadPoolExecutor] = None
+        self._llm_executor: Optional[ThreadPoolExecutor] = None
 
-    async def get_embedding(self, text: str) -> List[float]:
-        """Generate embedding for a single text input"""
-        try:
-            # Convert to list for consistency
-            if isinstance(text, str):
-                text = [text]
-
-            # Generate embeddings
-            embeddings = self.model.encode(
-                text,
-                convert_to_tensor=True,
-                show_progress_bar=False
-            )
+        self._executor_lock = threading.Lock()
+        self._llm_lock = threading.Lock()
+        self._tokenizer_lock = threading.Lock()
 
-            # Convert to standard Python list and return
-            embedding_list = embeddings[0].cpu().tolist()
-            logger.debug(f"Generated embedding with dimension: {len(embedding_list)}")
-            return embedding_list
-        except Exception as e:
-            logger.error(f"Error generating embedding: {str(e)}")
-            raise
+        self._gpu_semaphores: Dict[int, asyncio.Semaphore] = {}
+        self._gpu_semaphore_lock = threading.Lock()
 
-    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Generate embeddings for multiple text inputs in batches"""
-        try:
-            # Generate embeddings
-            embeddings = self.model.encode(
-                texts,
-                convert_to_tensor=True,
-                batch_size=MAX_BATCH_SIZE,
-                show_progress_bar=len(texts) > 10
-            )
+        self._shutting_down = False
 
-            # Convert to standard Python list
-            embedding_list = embeddings.cpu().tolist()
-            logger.info(f"Generated {len(embedding_list)} embeddings")
-            return embedding_list
-        except Exception as e:
-            logger.error(f"Error generating batch embeddings: {str(e)}")
-            raise
+        logger.info(
+            "EmbeddingService initialized | device=%s | workers=%s | gpu_limit=%s",
+            EMBEDDING_DEVICE,
+            EXECUTOR_MAX_WORKERS,
+            DEFAULT_MAX_CONCURRENT_GPU_TASKS,
+        )
 
-    async def summarize_user_profile(self, profile: WeaviateUserProfile) -> ProfileSummaryResult:
-        """Generate a comprehensive summary of a user profile optimized for embedding and semantic search."""
-        try:
-            logger.info(f"Summarizing profile for user: {profile.github_username}")
-
-            bio = profile.bio or "No bio provided"
-            languages = ", ".join(profile.languages) if profile.languages else "No languages specified"
-            topics = ", ".join(profile.topics) if profile.topics else "No topics specified"
-
-            prs_info = []
-            for pr in profile.pull_requests:
-                pr_desc = pr.body if pr.body else "No description"
-                prs_info.append(f"{pr.title} in {pr.repository}: {pr_desc}")
-            pull_requests_text = " | ".join(prs_info) if prs_info else "No recent pull requests"
-
-            stats_text = f"Followers: {profile.followers_count}, Following: {profile.following_count}, Total Stars: {profile.total_stars_received}, Total Forks: {profile.total_forks}"
-
-            prompt = PROFILE_SUMMARIZATION_PROMPT.format(
-                github_username=profile.github_username,
-                bio=bio,
-                languages=languages,
-                pull_requests=pull_requests_text,
-                topics=topics,
-                stats=stats_text
-            )
+    def _get_gpu_semaphore(self, limit: Optional[int]) -> asyncio.Semaphore:
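+        """Return an asyncio.Semaphore bounding concurrent GPU encodes.
+
+        Semaphores are cached per concurrency limit, so callers passing
+        different limits are throttled independently of one another.
+        """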
+        concurrency = limit or DEFAULT_MAX_CONCURRENT_GPU_TASKS
 
-            logger.info(f"Sending profile summarization request to LLM for {profile.github_username}")
-            response = await self.llm.ainvoke([HumanMessage(content=prompt)])
-            summary_text = response.content.strip()
+        with self._gpu_semaphore_lock:
+            if concurrency not in self._gpu_semaphores:
+                self._gpu_semaphores[concurrency] = asyncio.Semaphore(concurrency)
+            return self._gpu_semaphores[concurrency]
 
-            # Estimate token count (rough approximation: 1 token ≈ 4 characters)
-            token_estimate = len(summary_text) // 4
-            logger.info(
-                f"Generated profile summary for {profile.github_username}: {len(summary_text)} chars (~{token_estimate} tokens)"
-            )
+    @property
+    def embedding_executor(self) -> ThreadPoolExecutor:
+        if self._embedding_executor is None:
+            with self._executor_lock:
+                if self._embedding_executor is None:
+                    self._embedding_executor = ThreadPoolExecutor(
+                        max_workers=EXECUTOR_MAX_WORKERS,
+                        thread_name_prefix="embedding-worker",
+                    )
+        return self._embedding_executor
 
-            embedding = await self.get_embedding(summary_text)
+    @property
+    def llm_executor(self) -> ThreadPoolExecutor:
+        if self._llm_executor is None:
+            with self._executor_lock:
+                if self._llm_executor is None:
+                    self._llm_executor = ThreadPoolExecutor(
+                        max_workers=1,
+                        thread_name_prefix="llm-worker",
+                    )
+        return self._llm_executor
 
-            return ProfileSummaryResult(
-                summary_text=summary_text,
-                token_count_estimate=token_estimate,
-                embedding=embedding
-            )
+    @property
+    def model(self) -> SentenceTransformer:
+        # First check instance shutdown flag
+        if self._shutting_down:
+            raise RuntimeError("This EmbeddingService instance is shutting down")
+
+        # Then check and load model with proper locking
+        if EmbeddingService._global_model is None:
+            with EmbeddingService._global_model_lock:
+                # Check global shutdown flag inside the lock
+                if EmbeddingService._shutting_down_global:
+                    raise RuntimeError("EmbeddingService is shutting down globally")
+
+                if EmbeddingService._global_model is None:
+                    logger.info("Loading embedding model: %s", MODEL_NAME)
+                    EmbeddingService._global_model = SentenceTransformer(
+                        MODEL_NAME,
+                        device=EMBEDDING_DEVICE,
+                    )
+        return EmbeddingService._global_model
 
-        except Exception as e:
-            logger.error(f"Error summarizing profile for {profile.github_username}: {str(e)}")
-            raise
+    @property
+    def llm(self) -> ChatGoogleGenerativeAI:
+        if self._llm is None:
+            with self._llm_lock:
+                if self._llm is None:
+                    self._llm = ChatGoogleGenerativeAI(
+                        model=settings.github_agent_model,
+                        temperature=0.3,
+                        google_api_key=settings.gemini_api_key,
+                    )
+        return self._llm
 
-    async def process_user_profile(self, profile: WeaviateUserProfile) -> tuple[WeaviateUserProfile, List[float]]:
-        """Process a user profile by generating summary and embedding, then updating the profile object."""
-        try:
-            logger.info(f"Processing user profile for Weaviate storage: {profile.github_username}")
+    @property
+    def tokenizer(self) -> Optional[Any]:
+        if not TIKTOKEN_AVAILABLE:
+            return None
+
+        if self._tokenizer is None:
+            with self._tokenizer_lock:
+                if self._tokenizer is None:
+                    self._tokenizer = tiktoken.get_encoding("cl100k_base")
+        return self._tokenizer
+
+    def _count_tokens(self, text: str) -> int:
+        if self.tokenizer:
+            try:
+                return len(self.tokenizer.encode(text))
+            except Exception as e:
+                logger.warning("Tokenizer failed: %s; falling back to heuristic", e)
+        # Fallback: rough word-based estimate (~1.3 tokens per word)
+        return max(1, int(len(text.split()) * 1.3))
+
+    async def _encode(
+        self,
+        texts: List[str],
+        max_concurrent_tasks: Optional[int],
+    ) -> torch.Tensor:
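+        """Encode texts to a tensor without blocking the event loop.
+
+        Input is chunked into MAX_BATCH_SIZE slices, each encoded on the
+        embedding executor with the device-side batch capped at
+        SAFE_BATCH_SIZE, while a semaphore slot is held for the whole call.
+        """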
back to heuristic") + # Fallback estimation based on words + return max(1, int(len(text.split()) * 1.3)) + + async def _encode( + self, + texts: List[str], + max_concurrent_tasks: Optional[int], + ) -> torch.Tensor: + if not texts: + embedding_dim = self.model.get_sentence_embedding_dimension() + return torch.empty((0, embedding_dim)) + + semaphore = self._get_gpu_semaphore(max_concurrent_tasks) + + async with semaphore: + loop = asyncio.get_running_loop() + outputs: List[torch.Tensor] = [] + + for i in range(0, len(texts), MAX_BATCH_SIZE): + batch = texts[i : i + MAX_BATCH_SIZE] + + tensor = await loop.run_in_executor( + self.embedding_executor, + lambda b=batch: self.model.encode( + b, + convert_to_tensor=True, + normalize_embeddings=True, + batch_size=min(len(b), SAFE_BATCH_SIZE), + ), + ) - summary_result = await self.summarize_user_profile(profile) + outputs.append(tensor) + await asyncio.sleep(0) - profile.profile_text_for_embedding = summary_result.summary_text + return torch.cat(outputs, dim=0) - logger.info( - f"Successfully processed profile for {profile.github_username}: summary generated with {summary_result.token_count_estimate} estimated tokens" - ) + async def get_embedding( + self, + text: str, + max_concurrent_tasks: Optional[int] = None, + ) -> List[float]: + tensor = await self._encode([text], max_concurrent_tasks) + return tensor[0].cpu().tolist() - return profile, summary_result.embedding + async def get_embeddings( + self, + texts: List[str], + max_concurrent_tasks: Optional[int] = None, + ) -> List[List[float]]: + tensor = await self._encode(texts, max_concurrent_tasks) + return tensor.cpu().tolist() - except Exception as e: - logger.error(f"Error processing user profile for Weaviate: {str(e)}") - raise - - async def search_similar_profiles(self, query_text: str, limit: int = 10) -> List[Dict[str, Any]]: - """ - Search for similar profiles using embedding similarity. - This method generates an embedding for the query and searches for similar contributors. 
- """ + async def _invoke_llm(self, messages: List[HumanMessage]) -> str: try: - logger.info(f"Searching for similar profiles with query: {query_text[:100]}") - - query_embedding = await self.get_embedding(query_text) - - logger.info(f"Generated query embedding with dimension: {len(query_embedding)}") - - from app.database.weaviate.operations import search_similar_contributors - - results = await search_similar_contributors( - query_embedding=query_embedding, - limit=limit, - min_distance=0.5 + response = await self.llm.ainvoke(messages) + return response.content.strip() + except Exception as e: + logger.exception( + "LLM invocation failed for profile=%s", + getattr(messages[0], "content", "")[:50], ) - logger.info(f"Found {len(results)} similar contributors for query") - return results - - except Exception as e: - logger.error(f"Error searching similar profiles: {str(e)}") - raise + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + self.llm_executor, + lambda: self.llm.invoke(messages), + ) + return response.content.strip() + + async def summarize_user_profile( + self, + profile: WeaviateUserProfile, + max_concurrent_tasks: Optional[int] = None, + ) -> ProfileSummaryResult: + prompt = PROFILE_SUMMARIZATION_PROMPT.format( + github_username=profile.github_username, + bio=profile.bio or "No bio", + languages=", ".join(profile.languages or []), + topics=", ".join(profile.topics or []), + pull_requests=" | ".join( + f"{pr.title}: {pr.body or ''}" for pr in profile.pull_requests + ) or "No PRs", + stats=f"Followers={profile.followers_count}, Stars={profile.total_stars_received}", + ) + + summary = await self._invoke_llm([HumanMessage(content=prompt)]) + token_count = self._count_tokens(summary) + embedding = await self.get_embedding(summary, max_concurrent_tasks) + + return ProfileSummaryResult( + summary_text=summary, + token_count_estimate=token_count, + embedding=embedding, + ) + + async def process_user_profile( + self, + profile: WeaviateUserProfile, + max_concurrent_tasks: Optional[int] = None, + ) -> tuple[WeaviateUserProfile, List[float]]: + summary_result = await self.summarize_user_profile(profile, max_concurrent_tasks) + profile.profile_text_for_embedding = summary_result.summary_text + return profile, summary_result.embedding + + async def search_similar_profiles( + self, + query_text: str, + limit: int = 10, + max_concurrent_tasks: Optional[int] = None, + ) -> List[Dict[str, Any]]: + query_embedding = await self.get_embedding(query_text, max_concurrent_tasks) + from app.database.weaviate.operations import search_similar_contributors + + return await search_similar_contributors( + query_embedding=query_embedding, + limit=limit, + min_distance=0.5, + ) def get_model_info(self) -> Dict[str, Any]: - """Get information about the model being used""" return { - "model_name": self.model_name, - "device": self.device, + "model_name": MODEL_NAME, + "device": EMBEDDING_DEVICE, "embedding_size": self.model.get_sentence_embedding_dimension(), + "tiktoken_available": TIKTOKEN_AVAILABLE, + "safe_batch_size": SAFE_BATCH_SIZE, + "max_batch_size": MAX_BATCH_SIZE, + "default_max_concurrent_gpu_tasks": DEFAULT_MAX_CONCURRENT_GPU_TASKS, + "executor_workers": EXECUTOR_MAX_WORKERS, + "model_loaded": EmbeddingService._global_model is not None, + "global_shutdown": EmbeddingService._shutting_down_global, + "instance_shutdown": self._shutting_down, } - def clear_cache(self): - """Clear the model cache to free memory""" - if self._model: - del self._model - self._model = None - if 
+        query_embedding = await self.get_embedding(query_text, max_concurrent_tasks)
+        # Local import so Weaviate operations are not pulled in at module load
+        from app.database.weaviate.operations import search_similar_contributors
+
+        return await search_similar_contributors(
+            query_embedding=query_embedding,
+            limit=limit,
+            min_distance=0.5,
+        )
 
     def get_model_info(self) -> Dict[str, Any]:
-        """Get information about the model being used"""
+        """Return model, batching, and lifecycle information."""
         return {
-            "model_name": self.model_name,
-            "device": self.device,
+            "model_name": MODEL_NAME,
+            "device": EMBEDDING_DEVICE,
             "embedding_size": self.model.get_sentence_embedding_dimension(),
+            "tiktoken_available": TIKTOKEN_AVAILABLE,
+            "safe_batch_size": SAFE_BATCH_SIZE,
+            "max_batch_size": MAX_BATCH_SIZE,
+            "default_max_concurrent_gpu_tasks": DEFAULT_MAX_CONCURRENT_GPU_TASKS,
+            "executor_workers": EXECUTOR_MAX_WORKERS,
+            "model_loaded": EmbeddingService._global_model is not None,
+            "global_shutdown": EmbeddingService._shutting_down_global,
+            "instance_shutdown": self._shutting_down,
         }
 
-    def clear_cache(self):
-        """Clear the model cache to free memory"""
-        if self._model:
-            del self._model
-            self._model = None
-        if self._llm:
-            del self._llm
+    def shutdown(self) -> None:
+        """Shut down this instance's executors and release per-instance resources."""
+        self._shutting_down = True
+        logger.info("Shutting down EmbeddingService instance")
+
+        with self._executor_lock:
+            if self._embedding_executor:
+                self._embedding_executor.shutdown(wait=True)
+                self._embedding_executor = None
+            if self._llm_executor:
+                self._llm_executor.shutdown(wait=True)
+                self._llm_executor = None
+
+        # Clear instance-specific resources
+        with self._llm_lock:
             self._llm = None
-        # Force garbage collection
-        import gc
-        gc.collect()
+
+        with self._tokenizer_lock:
+            self._tokenizer = None
+
+        logger.info("EmbeddingService instance shutdown complete")
+
+    @classmethod
+    def shutdown_global(cls) -> None:
+        """Shut down global resources (the shared model) for all instances."""
+        with cls._global_model_lock:
+            cls._shutting_down_global = True
+            cls._global_model = None
+
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        logger.info("Cleared embedding service cache")
+            logger.info("Cleared GPU cache")
+
+        logger.info("EmbeddingService global shutdown complete")
+
+    async def __aenter__(self) -> "EmbeddingService":
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:
+        # NOTE: executor shutdown joins worker threads synchronously
+        self.shutdown()
\ No newline at end of file