Skip to content

Commit 1acbe8d

Browse files
committed
feat: 文档状态70%
1 parent 119bba0 commit 1acbe8d

File tree

20 files changed

+483
-113
lines changed

20 files changed

+483
-113
lines changed

apps/common/db/search.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from django.db.models import QuerySet
1313

1414
from common.db.compiler import AppSQLCompiler
15-
from common.db.sql_execute import select_one, select_list
15+
from common.db.sql_execute import select_one, select_list, update_execute
1616
from common.response.result import Page
1717

1818

@@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
109109
return select_list(exec_sql, exec_params)
110110

111111

112+
def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
                  field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
                  with_table_name=False):
    """
    Generate a native SQL statement from query-set conditions and run it as an update.

    :param queryset: a single QuerySet, or a dict of QuerySets, supplying the conditions
    :param select_string: SQL prefix the generated conditions are attached to
                          (must not itself contain where / limit clauses)
    :param field_replace_dict: fields that need to be replaced
    :param with_table_name: whether the generated SQL includes the table name
    :return: whatever ``update_execute`` returns for the generated SQL and parameters
    """
    # A dict of query sets and a single query set are compiled by different helpers
    # that share the same argument list.
    build_sql = generate_sql_by_query_dict if isinstance(queryset, Dict) else generate_sql_by_query
    exec_sql, exec_params = build_sql(queryset, select_string, field_replace_dict, with_table_name)
    return update_execute(exec_sql, exec_params)
128+
129+
112130
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
113131
"""
114132
分页查询

apps/common/event/listener_manage.py

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,18 @@
1313
from typing import List
1414

1515
import django.db.models
16+
from django.db import models
1617
from django.db.models import QuerySet
18+
from django.db.models.functions import Substr, Reverse
1719
from langchain_core.embeddings import Embeddings
1820

1921
from common.config.embedding_config import VectorStore
20-
from common.db.search import native_search, get_dynamics_model
21-
from common.event.common import embedding_poxy
22+
from common.db.search import native_search, get_dynamics_model, native_update
23+
from common.db.sql_execute import sql_execute, update_execute
2224
from common.util.file_util import get_file_content
2325
from common.util.lock import try_lock, un_lock
24-
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
26+
from common.util.page_utils import page
27+
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
2528
from embedding.models import SourceType, SearchMode
2629
from smartdoc.conf import PROJECT_DIR
2730

@@ -114,7 +117,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
114117
@param embedding_model: 向量模型
115118
"""
116119
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
117-
status = Status.success
120+
# 更新到开始状态
121+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
118122
try:
119123
data_list = native_search(
120124
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
@@ -125,23 +129,89 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
125129
# 删除段落
126130
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
127131

128-
def is_save_function():
129-
return QuerySet(Paragraph).filter(id=paragraph_id).exists()
132+
def is_the_task_interrupted():
133+
_paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
134+
if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
135+
return True
136+
return False
130137

131138
# 批量向量化
132-
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
139+
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
140+
# Mark the paragraph's embedding task as successful
141+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
142+
State.SUCCESS)
133143
except Exception as e:
134144
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
135-
status = Status.error
145+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
146+
State.FAILURE)
136147
finally:
137-
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
138148
max_kb.info(f'结束--->向量化段落:{paragraph_id}')
139149

140150
@staticmethod
141151
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
142152
# 批量向量化
143153
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)
144154

155+
@staticmethod
156+
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
    """
    Build a handler that embeds every paragraph of a batch.

    @param embedding_model: embedding model forwarded to ``embedding_by_paragraph``
    @param is_the_task_interrupted: zero-argument callable; a truthy result stops the batch early
    @param post_apply: zero-argument callable invoked after a batch has been handled
    @return: a function accepting an iterable of dict-like rows that expose an ``id`` key
    """
    def embedding_paragraph_apply(paragraph_list):
        for paragraph in paragraph_list:
            # Re-check before each paragraph so an interruption takes effect mid-batch
            if is_the_task_interrupted():
                break
            ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
        # NOTE(review): assumed to run once per batch (after the loop), also when the
        # loop was interrupted — confirm against the committed indentation.
        post_apply()

    return embedding_paragraph_apply
165+
166+
@staticmethod
167+
def get_aggregation_document_status(document_id):
    """
    Build a zero-argument callback that runs ``update_document_status_meta.sql`` for a document.

    @param document_id: id of the document; it is bound twice as the statement's parameters
    @return: the callback; the SQL file is read each time the callback is invoked
    """
    def aggregation_document_status():
        sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        update_execute(get_file_content(sql_path), [document_id, document_id])

    return aggregation_document_status
174+
175+
@staticmethod
176+
def post_update_document_status(document_id, task_type: TaskType):
    """
    Settle the final state of ``task_type`` on a document once its task has ended.

    The document's state for the task becomes REVOKED when a revoke had been requested
    and SUCCESS otherwise; it is then overridden with FAILURE when the document's
    ``status_meta['aggs']`` reports at least one paragraph whose state for this task is
    FAILURE. Finally every paragraph of the document that is still in REVOKE for this
    task is moved to REVOKED.

    @param document_id: id of the document to finalise
    @param task_type: the task whose state is being settled
    """
    _document = QuerySet(Document).filter(id=document_id).first()
    if _document is None:
        # The document no longer exists (e.g. it was deleted while the task ran), so
        # there is nothing to finalise. Returning quietly matters because this method is
        # called from cleanup code that still has to release a lock afterwards.
        return

    status = Status(_document.status)
    if status[task_type] == State.REVOKE:
        status[task_type] = State.REVOKED
    else:
        status[task_type] = State.SUCCESS
    for item in _document.status_meta.get('aggs', []):
        agg_status = item.get('status')
        agg_count = item.get('count')
        # ``count`` may be absent from an aggregate entry; treat that as zero rather
        # than comparing None with an int.
        if Status(agg_status)[task_type] == State.FAILURE and (agg_count or 0) > 0:
            status[task_type] = State.FAILURE
    _document.status = str(status)
    _document.save()
    # Status strings are read right-to-left by ``Status`` (task type N is the N-th
    # character of the reversed string), so the state of this task is the single
    # character at 1-based position ``task_type.value`` of the reversed column.
    # Substr takes (expression, pos, length): the length must be 1, not the position.
    revoking_paragraphs = QuerySet(Paragraph).annotate(
        reversed_status=Reverse('status'),
        task_type_status=Substr('reversed_status', task_type.value, 1),
    ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id')
    ListenerManagement.update_status(revoking_paragraphs, task_type, State.REVOKED)
198+
199+
@staticmethod
200+
def update_status(query_set: QuerySet, taskType: TaskType, state: State):
    """
    Run ``update_paragraph_status.sql`` against the rows selected by ``query_set``.

    The SQL file is a template: its ``${...}`` placeholders are substituted textually
    (they are not bound as SQL parameters) before the statement is handed to
    ``native_update`` together with ``query_set``.

    @param query_set: selects the rows to update; its model also supplies the table name
    @param taskType: task whose state is written
    @param state: state to write
    """
    template_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql')
    exec_sql = get_file_content(template_path)
    placeholders = {
        '${bit_number}': len(TaskType),
        '${up_index}': taskType.value - 1,
        '${status_number}': state.value,
        '${next_index}': taskType.value + 1,
        '${table_name}': query_set.model._meta.db_table,
    }
    for placeholder, value in placeholders.items():
        exec_sql = exec_sql.replace(placeholder, str(value))
    native_update(query_set, exec_sql)
214+
145215
@staticmethod
146216
def embedding_by_document(document_id, embedding_model: Embeddings):
147217
"""
@@ -153,33 +223,28 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
153223
if not try_lock('embedding' + str(document_id)):
154224
return
155225
max_kb.info(f"开始--->向量化文档:{document_id}")
156-
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
157-
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
158-
status = Status.success
226+
# Mark the document's embedding task as STARTED
227+
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
159228
try:
160-
data_list = native_search(
161-
{'problem': QuerySet(
162-
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
163-
**{'paragraph.document_id': document_id}),
164-
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
165-
select_string=get_file_content(
166-
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
167229
# 删除文档向量数据
168230
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
169231

170-
def is_save_function():
171-
return QuerySet(Document).filter(id=document_id).exists()
172-
173-
# 批量向量化
174-
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
232+
def is_the_task_interrupted():
233+
document = QuerySet(Document).filter(id=document_id).first()
234+
if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
235+
return True
236+
return False
237+
238+
# 根据段落进行向量化处理
239+
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 10,
240+
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
241+
ListenerManagement.get_aggregation_document_status(
242+
document_id)),
243+
is_the_task_interrupted)
175244
except Exception as e:
176245
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
177-
status = Status.error
178246
finally:
179-
# 修改状态
180-
QuerySet(Document).filter(id=document_id).update(
181-
**{'status': status, 'update_time': datetime.datetime.now()})
182-
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
247+
ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
183248
max_kb.info(f"结束--->向量化文档:{document_id}")
184249
un_lock('embedding' + str(document_id))
185250

apps/common/util/page_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: page_utils.py
6+
@date:2024/11/21 10:32
7+
@desc:
8+
"""
9+
from math import ceil
10+
11+
12+
def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Feed ``query_set`` to ``handler`` in consecutive slices of at most ``page_size`` items.

    @param query_set: object providing ``count()`` and slice access
    @param page_size: maximum number of items handed to ``handler`` per call
    @param handler: callable receiving each slice
    @param is_the_task_interrupted: zero-argument callable checked before every slice;
                                    a truthy result stops the iteration
    @return: None
    """
    # The total is taken once up front; slices are then addressed by offset.
    # NOTE(review): offset slicing presumably relies on query_set yielding rows in a
    # stable order between calls — confirm callers pass an ordered query set.
    total_pages = ceil(query_set.count() / page_size)
    for page_index in range(total_pages):
        if is_the_task_interrupted():
            return
        start = page_index * page_size
        handler(query_set[start: start + page_size])
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Generated by Django 4.2.15 on 2024-11-22 14:44
2+
3+
import dataset.models.data_set
4+
from django.db import migrations, models
5+
6+
7+
class Migration(migrations.Migration):
    """
    Add a JSON ``status_meta`` field to ``document`` and ``paragraph`` and alter
    ``status`` on both models to a 20-character CharField with a callable default.
    """

    dependencies = [
        ('dataset', '0010_file_meta'),
    ]

    # NOTE(review): the ``status`` default below is the unbound ``Status.__str__``;
    # calling it with no arguments would raise TypeError. Confirm that nothing
    # evaluates this default through the migration state.
    operations = [
        migrations.AddField(
            model_name='document',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'),
        ),
        migrations.AddField(
            model_name='paragraph',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'),
        ),
        migrations.AlterField(
            model_name='document',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
        ),
        migrations.AlterField(
            model_name='paragraph',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
        ),
    ]

apps/dataset/models/data_set.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
@desc: 数据集
88
"""
99
import uuid
10+
from enum import Enum
1011

1112
from django.db import models
1213
from django.db.models.signals import pre_delete
@@ -18,13 +19,60 @@
1819
from users.models import User
1920

2021

21-
class Status(models.TextChoices):
22-
"""订单类型"""
23-
embedding = 0, '导入中'
24-
success = 1, '已完成'
25-
error = 2, '导入失败'
26-
queue_up = 3, '排队中'
27-
generating = 4, '生成问题中'
22+
class TaskType(Enum):
    """Kinds of task whose state is tracked in a status string."""
    # Embedding (vectorisation)
    EMBEDDING = 1
    # Problem generation
    GENERATE_PROBLEM = 2
    # Synchronisation
    SYNC = 3
29+
30+
31+
class State(Enum):
    """State of a single task; each value is a one-character string."""
    # Waiting
    PENDING = '0'
    # Running
    STARTED = '1'
    # Succeeded
    SUCCESS = '2'
    # Failed
    FAILURE = '3'
    # Cancellation requested
    REVOKE = '4'
    # Cancellation completed
    REVOKED = '5'
44+
45+
46+
class Status:
    """
    Per-task states packed into a string with one character per task type.

    The string is read right-to-left: the task type whose value is ``n`` occupies the
    ``n``-th character counted from the end. Positions missing from the string fall
    back to the state whose value is ``'2'``.
    """
    type_cls = TaskType
    state_cls = State

    def __init__(self, status: str = None):
        """
        Parse ``status`` into a mapping of task type to state.

        :param status: packed status string; ``None`` is treated like an empty string
        """
        reversed_chars = list(status[::-1] if status is not None else '')
        self.task_status = {}
        for task_type in self.type_cls:
            position = task_type.value - 1
            raw_state = reversed_chars[position] if position < len(reversed_chars) else '2'
            self.task_status[task_type] = self.state_cls(raw_state)

    @staticmethod
    def of(status: str):
        """Return a ``Status`` parsed from ``status``."""
        return Status(status)

    def __str__(self):
        """Pack the states back into a string, highest task-type value first."""
        ordered_types = sorted(self.type_cls, key=lambda item: item.value, reverse=True)
        return ''.join(self.task_status[task_type].value for task_type in ordered_types)

    def __setitem__(self, key, value):
        """Set the state stored for task type ``key``."""
        self.task_status[key] = value

    def __getitem__(self, item):
        """Return the state stored for task type ``item``."""
        return self.task_status[item]

    def update_status(self, task_type: TaskType, state: State):
        """Set the state stored for ``task_type``."""
        self.task_status[task_type] = state
2876

2977

3078
class Type(models.TextChoices):
@@ -42,6 +90,10 @@ def default_model():
4290
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
4391

4492

93+
def default_status_meta():
    """Return a fresh default value for a ``status_meta`` JSON field."""
    return dict(state_time={})
95+
96+
4597
class DataSet(AppModelMixin):
4698
"""
4799
数据集表
@@ -68,8 +120,8 @@ class Document(AppModelMixin):
68120
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
69121
name = models.CharField(max_length=150, verbose_name="文档名称")
70122
char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
71-
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
72-
default=Status.queue_up)
123+
status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
124+
status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta)
73125
is_active = models.BooleanField(default=True)
74126

75127
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
@@ -94,8 +146,8 @@ class Paragraph(AppModelMixin):
94146
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
95147
content = models.CharField(max_length=102400, verbose_name="段落内容")
96148
title = models.CharField(max_length=256, verbose_name="标题", default="")
97-
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
98-
default=Status.embedding)
149+
status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
150+
status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta)
99151
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
100152
is_active = models.BooleanField(default=True)
101153

@@ -145,7 +197,6 @@ class File(AppModelMixin):
145197

146198
meta = models.JSONField(verbose_name="文件关联数据", default=dict)
147199

148-
149200
class Meta:
150201
db_table = "file"
151202

@@ -161,7 +212,6 @@ def get_byte(self):
161212
return result['data']
162213

163214

164-
165215
@receiver(pre_delete, sender=File)
166216
def on_delete_file(sender, instance, **kwargs):
167217
select_one(f'SELECT lo_unlink({instance.loid})', [])

0 commit comments

Comments
 (0)