# context_manager.py — 520 lines (441 loc) · 22 KB
import json
import os
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional
from sentence_transformers import SentenceTransformer
import torch
from config import PATHS, EMBEDDING_MODEL, KAG_CONFIG
# The KAG solver has been moved inside the KAG project itself.
# Module-level logger; configuration is left to the application.
logger = logging.getLogger(__name__)
class ContextManager:
    """Manage static prompt context and KAG-based dynamic retrieval context.

    Responsibilities:
      - Load/save static prompt templates from ``context/static/prompts.json``.
      - Embed queries/passages with a BGE-style SentenceTransformer model
        (``query:`` / ``passage:`` prefixes).
      - Retrieve candidate context from a KAG knowledge-graph solver, score it
        with a mix of semantic distance, keyword overlap and metadata boosts,
        and cache the last result to avoid repeated solver calls.
    """

    def __init__(self) -> None:
        # Static prompt templates keyed by prompt type (e.g. "plan_prompt").
        self.static_context: Dict[str, str] = {}
        # Results of the most recent KAG reasoning call, cached so later
        # pipeline stages can reuse them without re-invoking the solver.
        self.last_kag_answer: str = ""
        self.last_kag_input_query: str = ""
        self.last_kag_tasks: List[Dict] = []
        # Single-entry retrieval cache (see load_dynamic_context).
        self._last_query: str = ""
        self._last_context: List[Dict] = []
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {device} 进行embedding计算")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL, device=device)
        self._init_kag_solver()
        self._load_static_context()

    def _embed_query(self, query: str) -> List[float]:
        """Embed a query string, prepending the ``query:`` prefix (BGE-model convention)."""
        prefixed_query = f"query: {query}"
        return self.embedding_model.encode(prefixed_query).tolist()

    def _embed_passage(self, text: str) -> List[float]:
        """Embed a passage string, prepending the ``passage:`` prefix (BGE-model convention)."""
        prefixed_text = f"passage: {text}"
        return self.embedding_model.encode(prefixed_text).tolist()

    def _init_kag_solver(self) -> None:
        """Initialize the KAG reasoning solver (KAG developer-mode wrapper).

        On any failure ``self.kag_solver`` is set to ``None`` and all callers
        degrade gracefully (empty results / error payloads).
        """
        try:
            # The solver wrapper lives inside the vendored KAG project; make it
            # importable by putting the KAG directory on sys.path.
            import sys
            base_dir = Path(__file__).parent
            kag_path = base_dir / "KAG"
            kag_path_str = str(kag_path.resolve())  # use an absolute path
            if kag_path_str not in sys.path:
                sys.path.insert(0, kag_path_str)
                logger.debug(f"已添加KAG路径到sys.path: {kag_path_str}")
            # Verify that the KAG directory and the inner "kag" package exist.
            kag_module_path = kag_path / "kag"
            if not kag_module_path.exists():
                raise FileNotFoundError(f"KAG模块目录不存在: {kag_module_path}")
            # Probe-import the kag package first so import failures produce
            # a detailed diagnostic before the wrapper import is attempted.
            try:
                import KAG.kag as kag
                logger.debug(f"KAG模块导入成功,路径: {getattr(kag, '__file__', 'unknown')}")
            except ImportError as e:
                logger.error(f"无法导入kag模块: {e}")
                logger.error(f"KAG路径: {kag_path_str}")
                logger.error(f"KAG目录存在: {kag_path.exists()}")
                logger.error(f"kag模块目录存在: {kag_module_path.exists()}")
                logger.error(f"当前sys.path前5项: {sys.path[:5]}")
                raise
            # Import the KAG solver wrapper and instantiate it.
            from KAG.kag.examples.MilitaryDeployment.solver.kag_solver_wrapper import KAGSolverWrapper
            self.kag_solver = KAGSolverWrapper()
            logger.info("KAG推理问答器初始化完成")
        except Exception as e:
            logger.warning(f"KAG推理器初始化失败: {e}", exc_info=True)
            self.kag_solver = None

    def _load_static_context(self) -> None:
        """Load static prompt templates from prompts.json.

        Raises:
            FileNotFoundError: if the prompts file does not exist.
        """
        static_file = PATHS["static_context_dir"] / "prompts.json"
        if static_file.exists():
            with open(static_file, "r", encoding="utf-8") as f:
                self.static_context = json.load(f)
        else:
            raise FileNotFoundError(
                f"提示词文件不存在: {static_file}\n"
                "请创建 context/static/prompts.json 文件,包含 plan_prompt, work_first_think_prompt, work_second_think_prompt, work_prompt, system_prompt 字段"
            )
        logger.info("静态上下文已加载")

    def _save_static_context(self) -> None:
        """Persist the in-memory static context back to prompts.json."""
        os.makedirs(PATHS["static_context_dir"], exist_ok=True)
        static_file = PATHS["static_context_dir"] / "prompts.json"
        with open(static_file, "w", encoding="utf-8") as f:
            json.dump(self.static_context, f, ensure_ascii=False, indent=2)

    def save_static_context(self, context_type: str, content: str) -> None:
        """Set one static prompt entry and persist all prompts to disk."""
        self.static_context[context_type] = content
        self._save_static_context()

    def load_static_context(self, context_type: str) -> str:
        """Return the static prompt for *context_type*, or "" if absent."""
        return self.static_context.get(context_type, "")

    def get_kg_data(self) -> Dict:
        """Return knowledge-graph data (entities and relations) from the solver.

        Returns an empty payload with an ``error`` field when the KAG solver
        failed to initialize.
        """
        if self.kag_solver:
            return self.kag_solver.get_kg_data()
        return {
            "entities": [],
            "relations": [],
            "entity_count": 0,
            "relation_count": 0,
            "error": "KAG推理器未初始化"
        }

    def _extract_keywords(self, query: str) -> List[str]:
        """Extract keywords from a query: Chinese word runs, digit runs, known tool and unit names.

        Returns a de-duplicated list (order is not guaranteed).
        """
        keywords = []
        # Runs of CJK unified ideographs.
        chinese_words = re.findall(r'[\u4e00-\u9fff]+', query)
        keywords.extend(chinese_words)
        # Numeric literals (e.g. distances, counts).
        numbers = re.findall(r'\d+', query)
        keywords.extend(numbers)
        # Known tool identifiers, matched case-insensitively.
        tool_names = [
            "buffer_filter_tool", "elevation_filter_tool",
            "vegetation_filter_tool", "slope_filter_tool"
        ]
        for tool in tool_names:
            if tool in query.lower():
                keywords.append(tool)
        # Known military unit names, matched verbatim.
        unit_names = [
            "轻步兵", "重装步兵", "机械化步兵", "坦克部队", "反坦克步兵",
            "自行火炮", "牵引火炮", "防空部队", "狙击手", "特种部队",
            "装甲侦察单位", "工兵部队", "后勤保障部队", "指挥单位", "无人机侦察控制单元"
        ]
        for unit in unit_names:
            if unit in query:
                keywords.append(unit)
        return list(set(keywords))

    def _calculate_keyword_score(self, doc_text: str, keywords: List[str]) -> float:
        """Score a document by (weighted) keyword occurrence, averaged over keywords.

        Digit keywords and tool names get double weight; matching is
        case-insensitive. Returns 0.0 when *keywords* is empty.
        """
        if not keywords:
            return 0.0
        doc_lower = doc_text.lower()
        score = 0.0
        for keyword in keywords:
            # Digits and tool names are strong signals -> double weight.
            if keyword.isdigit() or "tool" in keyword.lower():
                weight = 2.0
            else:
                weight = 1.0
            count = doc_lower.count(keyword.lower())
            score += count * weight
        return score / len(keywords)

    def _calculate_metadata_boost(self, metadata: Dict, query: str, keywords: List[str]) -> float:
        """Compute an additive score boost from metadata matches (unit/type/tool constraints)."""
        boost = 0.0
        if not metadata:
            return boost
        # Unit name in metadata matching the query or extracted keywords.
        unit_in_meta = metadata.get("unit", "")
        if unit_in_meta:
            if unit_in_meta in query or unit_in_meta in keywords:
                boost += KAG_CONFIG["metadata_boost_unit"]
        # Document type matched against query intent keywords.
        type_in_meta = metadata.get("type", "")
        if type_in_meta:
            if "部署" in query or "配置" in query:
                if type_in_meta == "deployment_rule":
                    boost += KAG_CONFIG["metadata_boost_type"]
            if "射程" in query:
                if type_in_meta == "equipment_info":
                    boost += KAG_CONFIG["metadata_boost_type"]
        # Tool name in metadata matching the query or keywords.
        tool_in_meta = metadata.get("tool", "")
        if tool_in_meta:
            if tool_in_meta in query or tool_in_meta in keywords:
                boost += KAG_CONFIG["metadata_boost_type"]
        return boost

    def _retrieve_from_kag(
        self,
        query: str,
        query_embedding: List[float],
        oversample: int
    ) -> List[Dict]:
        """Retrieve candidate documents for *query* via the KAG reasoning solver.

        Note: *query_embedding* and *oversample* are currently unused by the
        KAG-solver path; they are kept for interface compatibility.

        Returns a list of ``{"text", "metadata", "distance"}`` candidates, or
        an empty list when the solver is unavailable or fails.
        """
        if self.kag_solver:
            try:
                result = self.kag_solver.query(query)
                # Record the full KAG reasoning result so the plan stage can
                # reuse it without invoking KAG a second time.
                try:
                    self.last_kag_answer = result.get("answer") or result.get("raw_result") or ""
                    self.last_kag_input_query = result.get("input_query", query)
                    self.last_kag_tasks = result.get("tasks", [])
                except Exception:
                    self.last_kag_answer = ""
                    self.last_kag_input_query = query
                    self.last_kag_tasks = []
                # Convert the reasoning result to retrieval-candidate format.
                candidates = []
                # Prefer explicit references when present.
                references = result.get("references", [])
                if references and len(references) > 0:
                    for ref in references:
                        candidates.append({
                            "text": ref.get("text", ""),
                            "metadata": ref.get("metadata", {}),
                            "distance": ref.get("distance", 0.0)
                        })
                # Fall back to the reasoning answer itself.
                if not candidates and result.get("answer"):
                    answer = result.get("answer", "")
                    clean_answer = re.sub(r'<reference[^>]*></reference>', '', answer)
                    clean_answer = clean_answer.strip()
                    if clean_answer:
                        candidates.append({
                            "text": clean_answer,
                            "metadata": {
                                "source": "kag_reasoning",
                                "type": "answer",
                                "query": query
                            },
                            # KAG reasoning result: distance 0 marks high relevance.
                            "distance": 0.0
                        })
                        logger.info(f"[KAG] 使用推理答案作为上下文,长度: {len(clean_answer)}")
                # Last resort: use raw_result when it is a non-empty string.
                if not candidates and result.get("raw_result"):
                    raw_result = result.get("raw_result")
                    if isinstance(raw_result, str) and raw_result.strip():
                        clean_text = re.sub(r'<reference[^>]*></reference>', '', raw_result)
                        clean_text = clean_text.strip()
                        if clean_text:
                            candidates.append({
                                "text": clean_text,
                                "metadata": {
                                    "source": "kag_reasoning",
                                    "type": "raw_result",
                                    "query": query
                                },
                                "distance": 0.0
                            })
                            logger.info(f"[KAG] 使用原始结果作为上下文,长度: {len(clean_text)}")
                return candidates
            except Exception as e:
                logger.warning(f"KAG推理器检索失败: {e}", exc_info=True)
        # Solver unavailable: return an empty result set.
        logger.warning("KAG推理器未初始化,返回空结果")
        return []

    def load_dynamic_context(self, query: str, top_k: int = None, use_cache: bool = True) -> List[Dict]:
        """Retrieve and score context from the KAG knowledge graph.

        Args:
            query: Query text.
            top_k: Number of results to return; defaults to KAG_CONFIG["top_k"].
            use_cache: If True and the query matches the previous one, return
                the cached result and skip the solver call.

        Returns:
            List of scored context dicts (text, metadata, distance and the
            individual score components).
        """
        # Skip empty queries to avoid pointless KAG calls.
        if not query or not str(query).strip():
            logger.warning("[KAG检索] 空查询已跳过")
            return []
        # Serve from the single-entry cache when the query is unchanged.
        if use_cache and query == self._last_query and self._last_context:
            logger.info(f"[KAG检索] 使用缓存结果,query='{query[:50]}...'")
            return self._last_context
        if top_k is None:
            top_k = KAG_CONFIG["top_k"]
        logger.info(f"[KAG检索] query='{query}'")
        keywords = self._extract_keywords(query)
        logger.info(f"[KAG关键词] extracted keywords={keywords}")
        query_embedding = self._embed_query(query)
        all_candidates = self._retrieve_from_kag(
            query, query_embedding, KAG_CONFIG["oversample"]
        )
        logger.info(f"[KAG召回] candidates={len(all_candidates)}")
        if not all_candidates:
            logger.warning("[KAG] 未找到任何候选文档")
            return []
        max_distance = KAG_CONFIG["max_distance"]
        w_sem = KAG_CONFIG["w_sem"]
        w_kw = KAG_CONFIG["w_kw"]
        # Score every candidate within the distance threshold:
        # final = w_sem * (1 - distance) + w_kw * keyword_score + metadata_boost
        scored_candidates = []
        for candidate in all_candidates:
            distance = candidate["distance"]
            if distance > max_distance:
                continue
            semantic_score = 1.0 - distance
            keyword_score = self._calculate_keyword_score(candidate["text"], keywords)
            metadata_boost = self._calculate_metadata_boost(
                candidate["metadata"], query, keywords
            )
            final_score = w_sem * semantic_score + w_kw * keyword_score + metadata_boost
            scored_candidates.append({
                **candidate,
                "semantic_score": semantic_score,
                "keyword_score": keyword_score,
                "metadata_boost": metadata_boost,
                "final_score": final_score
            })
        scored_candidates.sort(key=lambda x: x["final_score"], reverse=True)
        logger.info(f"[KAG过滤] 阈值过滤前={len(all_candidates)}, 过滤后={len(scored_candidates)}")
        # Fallback: if too few results survived the threshold, relax the
        # distance bound and admit additional (low-confidence) candidates.
        min_k = KAG_CONFIG["min_k"]
        if len(scored_candidates) < min_k and len(all_candidates) > len(scored_candidates):
            relaxed_distance_increment = KAG_CONFIG.get("relaxed_distance_increment", 0.5)
            relaxed_max_distance = max_distance + relaxed_distance_increment
            logger.warning(f"[KAG降级] 结果不足{min_k}条,放宽阈值至{relaxed_max_distance}")
            for candidate in all_candidates:
                if candidate["distance"] <= relaxed_max_distance:
                    # Avoid re-adding candidates already scored (dedup by text).
                    if not any(c["text"] == candidate["text"] for c in scored_candidates):
                        distance = candidate["distance"]
                        semantic_score = 1.0 - distance
                        keyword_score = self._calculate_keyword_score(candidate["text"], keywords)
                        metadata_boost = self._calculate_metadata_boost(
                            candidate["metadata"], query, keywords
                        )
                        final_score = w_sem * semantic_score + w_kw * keyword_score + metadata_boost
                        scored_candidates.append({
                            **candidate,
                            "semantic_score": semantic_score,
                            "keyword_score": keyword_score,
                            "metadata_boost": metadata_boost,
                            "final_score": final_score,
                            "low_confidence": True
                        })
            scored_candidates.sort(key=lambda x: x["final_score"], reverse=True)
        final_results = scored_candidates[:top_k]
        logger.info(f"[KAG最终结果] 返回{len(final_results)}条:")
        for i, result in enumerate(final_results):
            logger.info(
                f"  [{i+1}] distance={result['distance']:.3f}, "
                f"semantic={result['semantic_score']:.3f}, "
                f"keyword={result['keyword_score']:.3f}, "
                f"metadata_boost={result['metadata_boost']:.3f}, "
                f"final={result['final_score']:.3f}, "
                f"low_confidence={result.get('low_confidence', False)}"
            )
        contexts = []
        for result in final_results:
            contexts.append({
                "text": result["text"],
                "metadata": result["metadata"],
                "distance": result["distance"],
                "semantic_score": result["semantic_score"],
                "keyword_score": result["keyword_score"],
                "metadata_boost": result["metadata_boost"],
                "final_score": result["final_score"],
                "low_confidence": result.get("low_confidence", False)
            })
        # Remember the result to avoid a repeated KAG call for the same query.
        if use_cache:
            self._last_query = query
            self._last_context = contexts
        return contexts

    def query_with_kag_reasoning(self, question: str) -> Dict:
        """Answer a question with the KAG reasoning solver.

        Args:
            question: The user question.

        Returns:
            A dict containing the answer, references and the retrieved source
            texts (``source_texts``, deduplicated by text content). On failure
            the dict carries an ``error`` field and empty payloads.
        """
        if not self.kag_solver:
            logger.warning("KAG推理器未初始化")
            return {
                "answer": "",
                "references": [],
                "source_texts": [],
                "error": "KAG推理器未初始化"
            }
        try:
            # Call KAG reasoning directly without fetching context first
            # (avoids running inference twice).
            result = self.kag_solver.query(question)
            # Extract the retrieved source chunks from the task memories.
            source_texts = []
            tasks = result.get("tasks", [])
            for task in tasks:
                task_memory = task.get("memory", {})
                # Pull retriever results out of the task memory.
                if "retriever" in task_memory:
                    retriever_output = task_memory["retriever"]
                    # The RetrieverOutput may arrive as a serialized dict ...
                    if isinstance(retriever_output, dict):
                        chunks = retriever_output.get("chunks", [])
                        for chunk in chunks:
                            if isinstance(chunk, dict):
                                # Format produced by chunk.to_dict().
                                content = chunk.get("content", chunk.get("text", ""))
                                if not content:
                                    # Fall back to alternative field names.
                                    content = chunk.get("chunk_content", chunk.get("desc", str(chunk)))
                                if content:
                                    source_texts.append({
                                        "text": content,
                                        "metadata": chunk.get("metadata", chunk.get("chunk_metadata", {})),
                                        "source": "kag_retriever"
                                    })
                    # ... or as a live RetrieverOutput object.
                    elif hasattr(retriever_output, "chunks"):
                        for chunk in retriever_output.chunks:
                            if hasattr(chunk, "content"):
                                source_texts.append({
                                    "text": chunk.content,
                                    "metadata": chunk.metadata if hasattr(chunk, "metadata") else {},
                                    "source": "kag_retriever"
                                })
                            elif hasattr(chunk, "to_dict"):
                                chunk_dict = chunk.to_dict()
                                content = chunk_dict.get("content", chunk_dict.get("text", ""))
                                if content:
                                    source_texts.append({
                                        "text": content,
                                        "metadata": chunk_dict.get("metadata", {}),
                                        "source": "kag_retriever"
                                    })
                # task.result may also carry a RetrieverOutput.
                task_result = task.get("result")
                if task_result:
                    if isinstance(task_result, dict):
                        if "chunks" in task_result:
                            chunks = task_result["chunks"]
                            for chunk in chunks:
                                if isinstance(chunk, dict):
                                    content = chunk.get("content", chunk.get("text", ""))
                                    if content:
                                        source_texts.append({
                                            "text": content,
                                            "metadata": chunk.get("metadata", {}),
                                            "source": "kag_retriever"
                                        })
                    elif hasattr(task_result, "chunks"):
                        for chunk in task_result.chunks:
                            if hasattr(chunk, "content"):
                                source_texts.append({
                                    "text": chunk.content,
                                    "metadata": chunk.metadata if hasattr(chunk, "metadata") else {},
                                    "source": "kag_retriever"
                                })
            # Deduplicate by text content, preserving first-seen order.
            seen_texts = set()
            unique_source_texts = []
            for source in source_texts:
                text = source.get("text", "")
                if text and text not in seen_texts:
                    seen_texts.add(text)
                    unique_source_texts.append(source)
            # BUGFIX: previously the raw (non-deduplicated) list was returned,
            # making the dedup pass above dead code.
            result["source_texts"] = unique_source_texts
            return result
        except Exception as e:
            logger.error(f"KAG推理查询失败: {e}", exc_info=True)
            return {
                "answer": "",
                "references": [],
                "source_texts": [],
                "error": str(e)
            }