
Commit 10e8347

Merge pull request #48 from DUT-Team-21TCLC-DT3/feat/tester_ai (Feat/tester ai)
2 parents: 4527fba + 86de256

15 files changed: +7030 additions, -88 deletions

ai_service/app/config.py (14 additions, 0 deletions)

```diff
@@ -28,6 +28,20 @@ def base_url(self) -> str:
 
     # LLMs
     gemini_api_key: Optional[str] = Field(None, alias="GEMINI_API_KEY")
+    gemini_api_keys_str: Optional[str] = Field(None, alias="GEMINI_API_KEYS")
+
+    @property
+    def gemini_api_keys(self) -> list[str]:
+        """
+        Returns a list of API keys.
+        Prioritizes GEMINI_API_KEYS (comma separated), falls back to GEMINI_API_KEY.
+        """
+        if self.gemini_api_keys_str:
+            return [k.strip() for k in self.gemini_api_keys_str.split(",") if k.strip()]
+        if self.gemini_api_key:
+            return [self.gemini_api_key]
+        return []
+
     cypher_model: str = Field("gpt-4o-mini", alias="CYPHER_MODEL")
     qa_model: str = Field("gpt-4o-mini", alias="QA_MODEL")
     cypher_temperature: float = Field(0.0, alias="CYPHER_TEMPERATURE")
```
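The fallback order encoded by the new property (a comma-separated GEMINI_API_KEYS wins, a single GEMINI_API_KEY is the fallback, otherwise an empty list) can be exercised in isolation. The sketch below mirrors the diff's logic but uses plain attributes instead of pydantic Field declarations so it runs without the service's dependencies; SettingsSketch is an illustrative name, not part of the codebase.

```python
# Standalone sketch of the key-resolution logic added to Settings (assumption: the real
# class is a pydantic-settings model; plain attributes keep this example dependency-free).
from typing import Optional


class SettingsSketch:
    def __init__(self, gemini_api_keys_str: Optional[str] = None,
                 gemini_api_key: Optional[str] = None):
        self.gemini_api_keys_str = gemini_api_keys_str
        self.gemini_api_key = gemini_api_key

    @property
    def gemini_api_keys(self) -> list[str]:
        # Comma-separated GEMINI_API_KEYS takes priority; single GEMINI_API_KEY is the fallback.
        if self.gemini_api_keys_str:
            return [k.strip() for k in self.gemini_api_keys_str.split(",") if k.strip()]
        if self.gemini_api_key:
            return [self.gemini_api_key]
        return []


print(SettingsSketch(gemini_api_keys_str="key-a, key-b ,").gemini_api_keys)  # ['key-a', 'key-b']
print(SettingsSketch(gemini_api_key="key-a").gemini_api_keys)                # ['key-a']
print(SettingsSketch().gemini_api_keys)                                      # []
```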

ai_service/app/gemini_constants.py (new file, 6 additions, 0 deletions)

```diff
@@ -0,0 +1,6 @@
+from app.dependencies import get_settings
+
+settings = get_settings()
+
+GEMINI_MODEL = "gemini-2.5-flash"
+GEMINI_API_KEYS = settings.gemini_api_keys
```
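The pipeline changes below all replace direct google.generativeai usage with a shared GeminiClient from ai_service/app/utils/gemini_utils.py, a file this PR adds but which is not shown in this excerpt. As a rough mental model only, a wrapper consistent with what the call sites expect (a generate_content/generate_content_stream pair, optional generation_config and system_instruction, and the multi-key GEMINI_API_KEYS list above) could look like the sketch below; the real implementation may rotate keys differently, or not at all.

```python
# HYPOTHETICAL sketch of the GeminiClient wrapper used throughout this PR; the real
# ai_service/app/utils/gemini_utils.py is not shown here. This version illustrates one
# plausible design: rotate across GEMINI_API_KEYS when a call fails, and expose the two
# call shapes the pipelines use.
import itertools
import logging
from typing import Any, Dict, Optional

import google.generativeai as genai

from app.gemini_constants import GEMINI_API_KEYS, GEMINI_MODEL

log = logging.getLogger(__name__)


class GeminiClientSketch:
    def __init__(
        self,
        generation_config: Optional[Dict[str, Any]] = None,
        system_instruction: Optional[str] = None,
    ):
        if not GEMINI_API_KEYS:
            raise ValueError("No Gemini API key configured (GEMINI_API_KEYS / GEMINI_API_KEY)")
        self._keys = itertools.cycle(GEMINI_API_KEYS)
        self._generation_config = generation_config
        self._system_instruction = system_instruction

    def _model(self) -> genai.GenerativeModel:
        # Point the SDK at the next key before each call so repeated failures rotate keys.
        genai.configure(api_key=next(self._keys))
        return genai.GenerativeModel(
            GEMINI_MODEL,
            generation_config=self._generation_config,
            system_instruction=self._system_instruction,
        )

    def generate_content(self, prompt: str):
        last_error: Optional[Exception] = None
        for _ in GEMINI_API_KEYS:  # at most one attempt per configured key
            try:
                return self._model().generate_content(prompt)
            except Exception as e:  # e.g. quota exhausted on the current key
                log.warning(f"Gemini call failed, rotating API key: {e}")
                last_error = e
        raise last_error

    def generate_content_stream(self, prompt: str):
        # Streamed responses are returned as-is; callers iterate and read chunk.text.
        return self._model().generate_content(prompt, stream=True)
```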

ai_service/app/pipelines/answer_composer.py (4 additions, 17 deletions)

```diff
@@ -1,18 +1,15 @@
 import logging
-import google.generativeai as genai
 from typing import Dict, Any, List
 from ..dependencies import get_settings
 from ..utils.stream_utils import create_metadata_chunk
+from ..utils.gemini_utils import GeminiClient
 
 log = logging.getLogger(__name__)
 
 class AnswerComposerService:
     def __init__(self):
         settings = get_settings()
-        if not settings.gemini_api_key:
-            raise ValueError("GEMINI_API_KEY not found")
-        genai.configure(api_key=settings.gemini_api_key)
-        self.model_name = 'gemini-2.5-pro'
+        self.client = GeminiClient(generation_config={"temperature": 0.2})
         self.base_url = settings.base_url # Get dynamic base URL
 
     def _build_citation_from_node_id(self, node_id: str) -> str:
@@ -191,13 +188,8 @@ def compose(
 
         HÃY VIẾT CÂU TRẢ LỜI:
         """
-        model = genai.GenerativeModel(
-            self.model_name,
-            generation_config={"temperature": 0.2} # Hơi sáng tạo một chút để viết văn mượt mà
-        )
-
         try:
-            response = model.generate_content(prompt)
+            response = self.client.generate_content(prompt)
             return response.text.strip()
         except Exception as e:
             log.error(f"Lỗi compose answer: {e}")
@@ -349,14 +341,9 @@ def compose_stream(
 
         HÃY VIẾT CÂU TRẢ LỜI:
         """
-        model = genai.GenerativeModel(
-            self.model_name,
-            generation_config={"temperature": 0.2}
-        )
-
         try:
             # Stream content directly, LLM includes citations in answer
-            response = model.generate_content(prompt, stream=True)
+            response = self.client.generate_content_stream(prompt)
             for chunk in response:
                 if chunk.text:
                     yield chunk.text
```
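For orientation, the two call shapes the composer now depends on (non-streaming in compose(), streaming in compose_stream()) reduce to the usage below. This assumes GeminiClient returns the SDK's usual response objects with a .text attribute, which the diff's call sites imply but the excerpt does not show.

```python
# Usage sketch of the two GeminiClient call shapes used by AnswerComposerService
# (assumes the responses behave like google.generativeai response objects).
from app.utils.gemini_utils import GeminiClient

client = GeminiClient(generation_config={"temperature": 0.2})

# Non-streaming, as in compose():
response = client.generate_content("Summarize the training regulations in one sentence.")
print(response.text.strip())

# Streaming, as in compose_stream():
for chunk in client.generate_content_stream("List the three main points."):
    if chunk.text:
        print(chunk.text, end="", flush=True)
```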

ai_service/app/pipelines/graph_search.py (8 additions, 28 deletions)

```diff
@@ -1,12 +1,12 @@
 import logging
 import json
 import re
-import google.generativeai as genai
 from typing import Dict, Any, List, Optional
 from dataclasses import dataclass
 from neo4j import GraphDatabase
 from langchain_core.prompts import PromptTemplate
 from ..dependencies import get_settings
+from ..utils.gemini_utils import GeminiClient
 
 log = logging.getLogger(__name__)
 
@@ -216,15 +216,7 @@ def __init__(self):
         self.driver = GraphSearchService._driver
 
         # Configure Gemini
-        if not settings.gemini_api_key:
-            raise ValueError("GEMINI_API_KEY not found")
-        genai.configure(api_key=settings.gemini_api_key)
-
-        self.gemini_model_name = 'gemini-2.5-pro'
-        self.model = genai.GenerativeModel(
-            self.gemini_model_name,
-            generation_config={"temperature": 0.0},
-        )
+        self.client = GeminiClient(generation_config={"temperature": 0.0})
 
         # Build schema snapshot if not exists
         if GraphSearchService._schema_snapshot is None:
@@ -363,11 +355,7 @@ def run_direct_graph_query(
         """
 
         try:
-            model = genai.GenerativeModel(
-                self.gemini_model_name,
-                generation_config={"temperature": 0.0},
-            )
-            response = model.generate_content(prompt)
+            response = self.client.generate_content(prompt)
             cypher_query = self._clean(response.text)
 
             log.info(f"Generated direct Cypher query: {cypher_query}...")
@@ -537,7 +525,7 @@ def generate_cypher(self, question: str, node_context: Dict[str, Any]) -> Cypher
         )
 
         try:
-            response = self.model.generate_content(prompt)
+            response = self.client.generate_content(prompt)
             cypher_query = self._clean(response.text)
             reasoning = f"Graph traversal từ node {node_id}"
 
@@ -693,11 +681,7 @@ def decide_context(self, question: str, node_text: str) -> Dict[str, str]:
            "reason": "giải thích ngắn gọn, tiếng Việt"
         }}
         """
-        model = genai.GenerativeModel(
-            self.gemini_model_name,
-            generation_config={"temperature": 0.1},
-        )
-        resp = model.generate_content(prompt)
+        resp = self.client.generate_content(prompt)
         try:
             text = resp.text.strip()
             if text.startswith("```"):
@@ -781,11 +765,7 @@ def filter_relevant_nodes(
         """
 
         try:
-            model = genai.GenerativeModel(
-                self.gemini_model_name,
-                generation_config={"temperature": 0.0},
-            )
-            response = model.generate_content(prompt)
+            response = self.client.generate_content(prompt)
             raw_text = response.text.strip()
 
             if raw_text.startswith("```"):
@@ -1143,7 +1123,7 @@ def run_graph_search_for_node_ids(
 
         # Step 6: Final Gemini reranking to select best candidate
         log.info("Step 6: Final Gemini reranking...")
-        rerank_model = genai.GenerativeModel(self.gemini_model_name)
+        # rerank_model = genai.GenerativeModel(self.gemini_model_name) # Removed
 
         numbered_candidates = []
         for idx, c in enumerate(candidates_for_llm, start=1):
@@ -1187,7 +1167,7 @@
         """
 
         try:
-            response = rerank_model.generate_content(selection_prompt)
+            response = self.client.generate_content(selection_prompt)
             raw_text = response.text.strip()
             if raw_text.startswith("```"):
                 raw_text = self._clean(raw_text)
```
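Several call sites above strip markdown fences from the raw model output before parsing it as Cypher or JSON (the raw_text.startswith("```") checks followed by self._clean). The helper below only illustrates that pattern; the service's actual _clean implementations are not shown in this excerpt and may differ.

```python
# Illustrative fence-stripping helper mirroring the pattern visible in the diff;
# not the project's actual _clean method.
import re


def strip_markdown_fences(text: str) -> str:
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = re.sub(r"^```[a-zA-Z]*\s*", "", cleaned)  # drop an opening fence such as ```cypher or ```json
        cleaned = re.sub(r"\s*```$", "", cleaned)           # drop a trailing fence
    return cleaned.strip()


print(strip_markdown_fences("```cypher\nMATCH (n) RETURN n LIMIT 1\n```"))
# MATCH (n) RETURN n LIMIT 1
```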

ai_service/app/pipelines/query_processor.py (6 additions, 15 deletions)

```diff
@@ -1,8 +1,8 @@
 import logging
 import json
-import google.generativeai as genai
 from typing import Dict, Any
 from ..dependencies import get_settings
+from ..utils.gemini_utils import GeminiClient
 
 log = logging.getLogger(__name__)
 
@@ -129,28 +129,19 @@
 
 class QueryPreprocessor:
     def __init__(self):
-        settings = get_settings()
-        api_key = settings.gemini_api_key
-        if not api_key:
-            raise ValueError("GEMINI_API_KEY not found in settings")
-
-        genai.configure(api_key=api_key)
-
         # Cấu hình Model
         generation_config = {
             "response_mime_type": "application/json",
             "temperature": 0.0,
         }
 
         # Khởi tạo model với System Instruction riêng biệt
-        gemini_model = 'gemini-2.5-flash'
-        self.model = genai.GenerativeModel(
-            gemini_model,
-            generation_config=generation_config,
-            system_instruction=SYSTEM_INSTRUCTION
+        self.client = GeminiClient(
+            system_instruction=SYSTEM_INSTRUCTION,
+            generation_config=generation_config
         )
 
-        log.info(f"QueryRewriter initialized successfully ({gemini_model})")
+        log.info(f"QueryRewriter initialized successfully")
 
     def _clean_json_string(self, text: str) -> str:
         """Làm sạch chuỗi JSON nếu model trả về markdown"""
@@ -172,7 +163,7 @@ def rewrite(self, question: str) -> Dict[str, Any]:
         try:
             user_prompt = f"Phân tích câu hỏi sau: \"{question}\""
 
-            response = self.model.generate_content(user_prompt)
+            response = self.client.generate_content(user_prompt)
             json_text = self._clean_json_string(response.text)
 
             try:
```
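For context, the configuration the rewriter now hands to GeminiClient (JSON response MIME type, temperature 0.0, a system instruction) maps directly onto the underlying google.generativeai SDK. The snippet below shows that direct SDK form with a placeholder API key and a placeholder instruction standing in for the project's SYSTEM_INSTRUCTION; it is an illustration, not the project's code.

```python
# JSON-mode configuration expressed directly against the google.generativeai SDK.
import json

import google.generativeai as genai

genai.configure(api_key="YOUR_GEMINI_API_KEY")  # placeholder key
model = genai.GenerativeModel(
    "gemini-2.5-flash",
    generation_config={"response_mime_type": "application/json", "temperature": 0.0},
    system_instruction="Classify the question and return JSON with keys 'intent' and 'rewritten'.",  # placeholder
)
response = model.generate_content('Analyze this question: "What are the graduation requirements?"')
print(json.loads(response.text))
```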

ai_service/app/pipelines/reranker.py (4 additions, 6 deletions)

```diff
@@ -1,18 +1,16 @@
 import logging
 import json
-import google.generativeai as genai
 from typing import List, Dict, Any
 from ..dependencies import get_settings
+from ..utils.gemini_utils import GeminiClient
 
 log = logging.getLogger(__name__)
 
 class RerankerService:
     def __init__(self):
-        settings = get_settings()
-        if not settings.gemini_api_key:
-            raise ValueError("GEMINI_API_KEY not found")
-        genai.configure(api_key=settings.gemini_api_key)
-        self.model_name = 'gemini-2.5-pro' # Or settings.GEMINI_RERANK_MODEL
+        # settings = get_settings()
+        # Initializing GeminiClient for potential future use (reranking currently disabled)
+        self.client = GeminiClient()
 
     def _clean(self, text: str) -> str:
         cleaned = text.strip()
```

ai_service/app/pipelines/social_chat.py (10 additions, 22 deletions)

```diff
@@ -1,26 +1,19 @@
 import logging
-import google.generativeai as genai
 from ..dependencies import get_settings
+from ..utils.gemini_utils import GeminiClient
 
 log = logging.getLogger(__name__)
 
 class SocialChatService:
     def __init__(self):
-        settings = get_settings()
-        self.api_key = settings.gemini_api_key
-        self.model_name = 'gemini-2.5-flash'
-
-        if not self.api_key:
-            log.warning("⚠️ GEMINI_API_KEY not found. Social Chat will fail.")
-            self.client = None
-        else:
-            try:
-                genai.configure(api_key=self.api_key)
-                self.client = genai.GenerativeModel(self.model_name)
-                log.info("Social Chat Service (Gemini Client) initialized.")
-            except Exception as e:
-                log.error(f"Failed to init Gemini Client: {e}")
-                self.client = None
+        # settings = get_settings()
+        try:
+            # Initialize GeminiClient with higher temperature for social chat
+            self.client = GeminiClient(generation_config={"temperature": 0.7})
+            log.info("Social Chat Service (Gemini Client) initialized.")
+        except Exception as e:
+            log.error(f"Failed to init Gemini Client: {e}")
+            self.client = None
 
     def chat(self, intent: str, question: str):
         """
@@ -55,12 +48,7 @@ def chat(self, intent: str, question: str):
 
         HÃY VIẾT CÂU TRẢ LỜI SÁNG TẠO VÀ PHÙ HỢP NHẤT:
         """
-        model = genai.GenerativeModel(
-            self.model_name,
-            generation_config={"temperature": 0.7} # Hơi sáng tạo để trả lời tự nhiên hơn
-        )
-
-        response = model.generate_content(prompt, stream=True)
+        response = self.client.generate_content_stream(prompt)
         for chunk in response:
             if chunk.text:
                 yield chunk.text
```
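Since chat() now yields plain text chunks from generate_content_stream, a caller can forward them directly to the client. The route below is only a sketch of such a consumer; the actual endpoint wiring in ai_service is not part of this diff, and the path and parameters here are illustrative.

```python
# Sketch of an HTTP consumer for the streaming chat generator (illustrative route).
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from app.pipelines.social_chat import SocialChatService

app = FastAPI()
social_chat = SocialChatService()


@app.get("/social-chat")
def social_chat_endpoint(intent: str, question: str):
    # SocialChatService.chat() yields text chunks; stream them straight through.
    return StreamingResponse(social_chat.chat(intent, question), media_type="text/plain")
```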
