|
9 | 9 | 4. Cluster summarization |
10 | 10 | """ |
11 | 11 | import logging |
12 | | -import os |
13 | 12 | import random |
14 | 13 | from typing import Dict, List, Optional, Tuple |
15 | 14 |
|
16 | 15 | import numpy as np |
17 | | -import yaml |
18 | 16 | from jinja2 import Template, StrictUndefined |
19 | 17 | from nexent.vector_database.base import VectorDatabaseCore |
20 | 18 | from sklearn.cluster import KMeans |
21 | 19 | from sklearn.metrics import silhouette_score |
22 | 20 | from sklearn.metrics.pairwise import cosine_similarity |
23 | 21 |
|
24 | 22 | from consts.const import LANGUAGE |
| 23 | +from utils.prompt_template_utils import ( |
| 24 | + get_document_summary_prompt_template, |
| 25 | + get_cluster_summary_reduce_prompt_template, |
| 26 | + get_cluster_summary_agent_prompt_template |
| 27 | +) |
25 | 28 |
|
26 | 29 | logger = logging.getLogger("document_vector_utils") |
27 | 30 |
|
28 | 31 |
|
29 | | -def _get_prompt_absolute_path(relative_path: str) -> str: |
30 | | - """ |
31 | | - Get absolute path for prompt files. |
32 | | - |
33 | | - Args: |
34 | | - relative_path: Relative path like 'backend/prompts/xxx.yaml' |
35 | | - |
36 | | - Returns: |
37 | | - Absolute path to the prompt file |
38 | | - """ |
39 | | - # Get the directory of this file and construct absolute path |
40 | | - current_dir = os.path.dirname(os.path.abspath(__file__)) |
41 | | - # Go up one level from utils to backend, then use the template path |
42 | | - backend_dir = os.path.dirname(current_dir) |
43 | | - absolute_path = os.path.join(backend_dir, relative_path.replace('backend/', '')) |
44 | | - return absolute_path |
45 | | - |
46 | | - |
47 | 32 | def get_documents_from_es(index_name: str, vdb_core: VectorDatabaseCore, sample_doc_count: int = 200) -> Dict[str, Dict]: |
48 | 33 | """ |
49 | 34 | Get document samples from Elasticsearch, aggregated by path_or_url |
@@ -567,14 +552,8 @@ def summarize_document(document_content: str, filename: str, language: str = LAN |
567 | 552 | Document summary text |
568 | 553 | """ |
569 | 554 | try: |
570 | | - # Select prompt file based on language |
571 | | - if language == LANGUAGE["ZH"]: |
572 | | - prompt_path = _get_prompt_absolute_path('backend/prompts/document_summary_agent_zh.yaml') |
573 | | - else: |
574 | | - prompt_path = _get_prompt_absolute_path('backend/prompts/document_summary_agent.yaml') |
575 | | - |
576 | | - with open(prompt_path, 'r', encoding='utf-8') as f: |
577 | | - prompts = yaml.safe_load(f) |
| 555 | + # Get prompt template from prompt_template_utils |
| 556 | + prompts = get_document_summary_prompt_template(language) |
578 | 557 |
|
579 | 558 | system_prompt = prompts.get('system_prompt', '') |
580 | 559 | user_prompt_template = prompts.get('user_prompt', '') |
@@ -645,14 +624,8 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z |
645 | 624 | Cluster summary text |
646 | 625 | """ |
647 | 626 | try: |
648 | | - # Select prompt file based on language |
649 | | - if language == LANGUAGE["ZH"]: |
650 | | - prompt_path = _get_prompt_absolute_path('backend/prompts/cluster_summary_reduce_zh.yaml') |
651 | | - else: |
652 | | - prompt_path = _get_prompt_absolute_path('backend/prompts/cluster_summary_reduce.yaml') |
653 | | - |
654 | | - with open(prompt_path, 'r', encoding='utf-8') as f: |
655 | | - prompts = yaml.safe_load(f) |
| 627 | + # Get prompt template from prompt_template_utils |
| 628 | + prompts = get_cluster_summary_reduce_prompt_template(language) |
656 | 629 |
|
657 | 630 | system_prompt = prompts.get('system_prompt', '') |
658 | 631 | user_prompt_template = prompts.get('user_prompt', '') |
@@ -957,9 +930,8 @@ def summarize_cluster_legacy(cluster_content: str, language: str = LANGUAGE["ZH" |
957 | 930 | Cluster summary text |
958 | 931 | """ |
959 | 932 | try: |
960 | | - prompt_path = _get_prompt_absolute_path('backend/prompts/cluster_summary_agent.yaml') |
961 | | - with open(prompt_path, 'r', encoding='utf-8') as f: |
962 | | - prompts = yaml.safe_load(f) |
| 933 | + # Get prompt template from prompt_template_utils |
| 934 | + prompts = get_cluster_summary_agent_prompt_template(language) |
963 | 935 |
|
964 | 936 | system_prompt = prompts.get('system_prompt', '') |
965 | 937 | user_prompt_template = prompts.get('user_prompt', '') |
|
0 commit comments