|
1 | 1 | import logging |
2 | 2 | import json |
3 | 3 | import re |
4 | | -import google.generativeai as genai |
5 | 4 | from typing import Dict, Any, List, Optional |
6 | 5 | from dataclasses import dataclass |
7 | 6 | from neo4j import GraphDatabase |
8 | 7 | from langchain_core.prompts import PromptTemplate |
9 | 8 | from ..dependencies import get_settings |
| 9 | +from ..utils.gemini_utils import GeminiClient |
10 | 10 |
|
11 | 11 | log = logging.getLogger(__name__) |
12 | 12 |
|
@@ -216,15 +216,7 @@ def __init__(self): |
216 | 216 | self.driver = GraphSearchService._driver |
217 | 217 |
|
218 | 218 | # Configure Gemini |
219 | | - if not settings.gemini_api_key: |
220 | | - raise ValueError("GEMINI_API_KEY not found") |
221 | | - genai.configure(api_key=settings.gemini_api_key) |
222 | | - |
223 | | - self.gemini_model_name = 'gemini-2.5-pro' |
224 | | - self.model = genai.GenerativeModel( |
225 | | - self.gemini_model_name, |
226 | | - generation_config={"temperature": 0.0}, |
227 | | - ) |
| 219 | + self.client = GeminiClient(generation_config={"temperature": 0.0}) |
228 | 220 |
|
229 | 221 | # Build schema snapshot if not exists |
230 | 222 | if GraphSearchService._schema_snapshot is None: |
@@ -363,11 +355,7 @@ def run_direct_graph_query( |
363 | 355 | """ |
364 | 356 |
|
365 | 357 | try: |
366 | | - model = genai.GenerativeModel( |
367 | | - self.gemini_model_name, |
368 | | - generation_config={"temperature": 0.0}, |
369 | | - ) |
370 | | - response = model.generate_content(prompt) |
| 358 | + response = self.client.generate_content(prompt) |
371 | 359 | cypher_query = self._clean(response.text) |
372 | 360 |
|
373 | 361 | log.info(f"Generated direct Cypher query: {cypher_query}...") |
@@ -537,7 +525,7 @@ def generate_cypher(self, question: str, node_context: Dict[str, Any]) -> Cypher |
537 | 525 | ) |
538 | 526 |
|
539 | 527 | try: |
540 | | - response = self.model.generate_content(prompt) |
| 528 | + response = self.client.generate_content(prompt) |
541 | 529 | cypher_query = self._clean(response.text) |
542 | 530 | reasoning = f"Graph traversal từ node {node_id}" |
543 | 531 |
|
@@ -693,11 +681,7 @@ def decide_context(self, question: str, node_text: str) -> Dict[str, str]: |
693 | 681 | "reason": "giải thích ngắn gọn, tiếng Việt" |
694 | 682 | }} |
695 | 683 | """ |
696 | | - model = genai.GenerativeModel( |
697 | | - self.gemini_model_name, |
698 | | - generation_config={"temperature": 0.1}, |
699 | | - ) |
700 | | - resp = model.generate_content(prompt) |
| 684 | + resp = self.client.generate_content(prompt) |
701 | 685 | try: |
702 | 686 | text = resp.text.strip() |
703 | 687 | if text.startswith("```"): |
@@ -781,11 +765,7 @@ def filter_relevant_nodes( |
781 | 765 | """ |
782 | 766 |
|
783 | 767 | try: |
784 | | - model = genai.GenerativeModel( |
785 | | - self.gemini_model_name, |
786 | | - generation_config={"temperature": 0.0}, |
787 | | - ) |
788 | | - response = model.generate_content(prompt) |
| 768 | + response = self.client.generate_content(prompt) |
789 | 769 | raw_text = response.text.strip() |
790 | 770 |
|
791 | 771 | if raw_text.startswith("```"): |
@@ -1143,7 +1123,7 @@ def run_graph_search_for_node_ids( |
1143 | 1123 |
|
1144 | 1124 | # Step 6: Final Gemini reranking to select best candidate |
1145 | 1125 | log.info("Step 6: Final Gemini reranking...") |
1146 | | - rerank_model = genai.GenerativeModel(self.gemini_model_name) |
| 1126 | + # rerank_model = genai.GenerativeModel(self.gemini_model_name) # Removed |
1147 | 1127 |
|
1148 | 1128 | numbered_candidates = [] |
1149 | 1129 | for idx, c in enumerate(candidates_for_llm, start=1): |
@@ -1187,7 +1167,7 @@ def run_graph_search_for_node_ids( |
1187 | 1167 | """ |
1188 | 1168 |
|
1189 | 1169 | try: |
1190 | | - response = rerank_model.generate_content(selection_prompt) |
| 1170 | + response = self.client.generate_content(selection_prompt) |
1191 | 1171 | raw_text = response.text.strip() |
1192 | 1172 | if raw_text.startswith("```"): |
1193 | 1173 | raw_text = self._clean(raw_text) |
|
0 commit comments