Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 33 additions & 34 deletions apps/application/serializers/application_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from rest_framework import serializers

from application.models import Chat, Application, ChatRecord
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.search import get_dynamics_model, native_search, native_page_search, native_page_handler
from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
from maxkb.conf import PROJECT_DIR
Expand Down Expand Up @@ -95,7 +95,8 @@ def get_query_set(self, select_ids=None):
'trample_num': models.IntegerField(),
'comparer': models.CharField(),
'application_chat.update_time': models.DateTimeField(),
'application_chat.id': models.UUIDField(), }))
'application_chat.id': models.UUIDField(),
'application_chat_record_temp.id': models.UUIDField()}))

base_query_dict = {'application_chat.application_id': self.data.get("application_id"),
'application_chat.update_time__gte': start_time,
Expand All @@ -106,7 +107,6 @@ def get_query_set(self, select_ids=None):
if 'username' in self.data and self.data.get('username') is not None:
base_query_dict['application_chat.asker__username__icontains'] = self.data.get('username')


if select_ids is not None and len(select_ids) > 0:
base_query_dict['application_chat.id__in'] = select_ids
base_condition = Q(**base_query_dict)
Expand Down Expand Up @@ -180,25 +180,26 @@ def to_row(row: Dict):
str(row.get('create_time').astimezone(pytz.timezone(TIME_ZONE)).strftime('%Y-%m-%d %H:%M:%S')
if row.get('create_time') is not None else None)]

@staticmethod
def reset_value(value):
    """Normalize a single cell value before it is written to the worksheet.

    Strings are stripped of characters openpyxl refuses to store;
    datetimes are converted to the project's fixed UTC offset.
    Any other value is returned unchanged.
    """
    if isinstance(value, str):
        return re.sub(ILLEGAL_CHARACTERS_RE, '', value)
    if isinstance(value, datetime.datetime):
        # NOTE(review): `_utcoffset` is a private pytz attribute holding the
        # zone's base (non-DST) offset — confirm this matches the intended
        # export timezone behavior.
        fixed_offset = datetime.timezone(pytz.timezone(TIME_ZONE)._utcoffset)
        return value.astimezone(fixed_offset)
    return value

def export(self, data, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ApplicationChatRecordExportRequest(data=data).is_valid(raise_exception=True)

data_list = native_search(self.get_query_set(data.get('select_ids')),
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
('export_application_chat_ee.sql' if ['PE', 'EE'].__contains__(
edition) else 'export_application_chat.sql'))),
with_table_name=False)

batch_size = 500

def stream_response():
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.title = 'Sheet1'

workbook = openpyxl.Workbook(write_only=True)
worksheet = workbook.create_sheet(title='Sheet1')
current_page = 1
page_size = 500
headers = [gettext('Conversation ID'), gettext('summary'), gettext('User Questions'),
gettext('Problem after optimization'),
gettext('answer'), gettext('User feedback'),
Expand All @@ -207,24 +208,22 @@ def stream_response():
gettext('Annotation'), gettext('USER'), gettext('Consuming tokens'),
gettext('Time consumed (s)'),
gettext('Question Time')]
for col_idx, header in enumerate(headers, 1):
cell = worksheet.cell(row=1, column=col_idx)
cell.value = header

for i in range(0, len(data_list), batch_size):
batch_data = data_list[i:i + batch_size]

for row_idx, row in enumerate(batch_data, start=i + 2):
for col_idx, value in enumerate(self.to_row(row), 1):
cell = worksheet.cell(row=row_idx, column=col_idx)
if isinstance(value, str):
value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
if isinstance(value, datetime.datetime):
eastern = pytz.timezone(TIME_ZONE)
c = datetime.timezone(eastern._utcoffset)
value = value.astimezone(c)
cell.value = value

worksheet.append(headers)
for data_list in native_page_handler(page_size, self.get_query_set(data.get('select_ids')),
primary_key='application_chat_record_temp.id',
primary_queryset='default_queryset',
get_primary_value=lambda item: item.get('id'),
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
('export_application_chat_ee.sql' if ['PE',
'EE'].__contains__(
edition) else 'export_application_chat.sql'))),
with_table_name=False):

for item in data_list:
row = [self.reset_value(v) for v in self.to_row(item)]
worksheet.append(row)
current_page = current_page + 1
output = BytesIO()
workbook.save(output)
output.seek(0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this review of the code update:

The main changes were focused on optimizations and improvements such as replacing native_page_search with native_page_handler, which allows for pagination more smoothly. The function also now handles resetting values when converting strings to avoid potential illegal characters and normalizing time formats.

To make it even better and cleaner:

  1. Add comments throughout to explain each step in processing the data to improve readability.
  2. Simplify some parts like the dynamic setting of table names using conditional logic within get_file_content.
  3. Ensure that all imports and module usages are consistent across the script.
  4. Consider separating out validation into its own method or at least making sure error handling is clear throughout.

Overall, these updates should enhance maintainability and reduce potential bugs while improving performance.

Expand Down
1 change: 1 addition & 0 deletions apps/application/sql/export_application_chat.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
SELECT
application_chat_record_temp.id AS id,
application_chat."id" as chat_id,
application_chat.abstract as abstract,
application_chat_record_temp.problem_text as problem_text,
Expand Down
1 change: 1 addition & 0 deletions apps/application/sql/export_application_chat_ee.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
SELECT
application_chat_record_temp.id AS id,
application_chat."id" as chat_id,
application_chat.abstract as abstract,
application_chat_record_temp.problem_text as problem_text,
Expand Down
58 changes: 52 additions & 6 deletions apps/common/db/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from common.db.sql_execute import select_one, select_list, update_execute
from common.result import Page


# 添加模型缓存
_model_cache = {}


def get_dynamics_model(attr: dict, table_name='dynamics'):
"""
获取一个动态的django模型
Expand All @@ -29,24 +30,24 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
# 创建缓存键,基于属性和表名
cache_key = hashlib.md5(f"{table_name}_{str(sorted(attr.items()))}".encode()).hexdigest()
# print(f'cache_key: {cache_key}')

# 如果模型已存在,直接返回缓存的模型
if cache_key in _model_cache:
return _model_cache[cache_key]

attributes = {
"__module__": "knowledge.models",
"Meta": type("Meta", (), {'db_table': table_name}),
**attr
}

# 使用唯一的类名避免冲突
class_name = f'Dynamics_{cache_key[:8]}'
model_class = type(class_name, (models.Model,), attributes)

# 缓存模型
_model_cache[cache_key] = model_class

return model_class


Expand Down Expand Up @@ -189,6 +190,51 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)


def native_page_handler(page_size: int,
                        queryset: QuerySet | Dict[str, QuerySet],
                        select_string: str,
                        field_replace_dict=None,
                        with_table_name=False,
                        primary_key=None,
                        get_primary_value=None,
                        primary_queryset: str = None,
                        ):
    """Keyset-paginate a native SQL query, yielding one page of rows at a time.

    Pages are fetched with a ``primary_key > last_seen`` cursor (keyset
    pagination) instead of OFFSET, so memory stays bounded for large exports.
    Results must therefore be ordered by ``primary_key`` ascending, which this
    function enforces on the primary queryset.

    :param page_size:         maximum number of rows per yielded page
    :param queryset:          a single QuerySet, or a dict of named QuerySets
    :param select_string:     the native SELECT statement to execute
    :param field_replace_dict: optional field-name replacement mapping
    :param with_table_name:   whether generated SQL keeps table prefixes
    :param primary_key:       field used as the pagination cursor
    :param get_primary_value: callable extracting the cursor value from a
                              result row (e.g. ``lambda row: row.get('id')``)
    :param primary_queryset:  when ``queryset`` is a dict, the key of the
                              QuerySet that carries ordering/filtering
    :return: generator yielding lists of result rows
    """

    def _build_sql(last_seen):
        # Re-generate the SQL with the cursor condition appended; ordering by
        # the primary key keeps the keyset pagination stable.
        if isinstance(queryset, Dict):
            primary = queryset[primary_queryset]
            if last_seen is not None:
                primary = primary.filter(**{f"{primary_key}__gt": last_seen})
            return generate_sql_by_query_dict(
                {**queryset, primary_queryset: primary.order_by(primary_key)},
                select_string, field_replace_dict, with_table_name)
        primary = queryset
        if last_seen is not None:
            primary = primary.filter(**{f"{primary_key}__gt": last_seen})
        return generate_sql_by_query(primary.order_by(primary_key),
                                     select_string, field_replace_dict,
                                     with_table_name)

    exec_sql, exec_params = _build_sql(None)
    # Total row count is taken once up front; rows inserted afterwards are
    # intentionally not exported.
    total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
    total = select_one(total_sql, exec_params)
    # LIMIT clause is loop-invariant (offset is always 0 with keyset paging).
    limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(0, page_size)
    processed_count = 0
    last_id = None
    while processed_count < total.get("count"):
        if last_id is not None:
            exec_sql, exec_params = _build_sql(last_id)
        result = select_list(exec_sql + " " + limit_sql, exec_params)
        if not result:
            # Rows were deleted while exporting: stop instead of raising
            # IndexError on result[-1] or looping forever.
            break
        yield result
        # Count actual rows returned; adding page_size would overcount a
        # final partial page.
        processed_count += len(result)
        last_id = get_primary_value(result[-1])


def get_field_replace_dict(queryset: QuerySet):
"""
获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code looks mostly correct, but there are a few improvements that can be made:

Improvements

  1. Use of exec() Safely: The use of exec(sql) to create dynamic models is generally safe when executed in a controlled environment where permissions are managed properly. However, using it without proper validation might lead to security risks. To mitigate this, consider alternative approaches such as dynamically constructing classes or functions.

  2. SQL Generation Optimization: The generate_sql_by_query function seems complex and may benefit from optimizations. Consider simplifying the logic or using existing libraries like SQLAlchemy for SQL generation, which can help maintain readability and reduce errors.

  3. Error Handling: Add error handling around database operations and other critical sections of the code. This will make debugging easier and ensure better fault tolerance.

  4. Comments for Complex Logic: Some parts of the code have complex logic. Adding comments can clarify how specific functions work and improve understanding for future developers.

  5. Documentation: Ensure that important parts of the code, such as method docstrings, clearly describe what they do and their parameters.

  6. Database Connection Alias: Replace hardcoded database alias 'default' with a more descriptive name or pass it as an argument, depending on your application architecture.

  7. Primary Key Value Processing: The get_primary_value function should handle different data types correctly, ensuring consistency across various datasets.

Here's a revised version with some of these optimizations suggested:

from typing import *
import hashlib
import inspect
from django.db import connection, DEFAULT_DB_ALIAS
from common.db.sql_execute import select_one, select_list, update_execute
from .result import Page

# 添加模型缓存
_model_cache = {}


def get_dynamics_model(attr: dict, table_name='dynamics'):
    """
    获取一个动态的Django模型
    """
    # 创建缓存键,基于属性和表名
    cache_key = hashlib.md5(f"{table_name}_{str(sorted(attr.items()))}".encode()).hexdigest()

    # 如果模型已存在,直接返回缓存的模型
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    attributes = {
        "__module__": "knowledge.models",
        "Meta": type("Meta", (), {'db_table': table_name}),
        **attr
    }

    # 使用唯一的类名避免冲突
    class_name = f"Dynamics_{cache_key[:8]}"
    model_class = type(class_name, (models.Model,), attributes)
    
    # 缓存模型
    _model_cache[cache_key] = model_class
    
    return model_class


def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet]):
    """
    分页搜索并返回Page对象
    """
    total_results = count_objects(queryset)
    paginated_queryset = paginate_queryset(queryset, current_page, page_size)
    return Page(total_results, list(paginated_queryset), current_page, page_size)


def native_page_handler(page_size: int,
                        queryset: QuerySet | Dict[str, QuerySet],
                        select_string: str,
                        field_replace_dict=None,
                        with_table_name=False,
                        primary_key='id',
                        get_primary_value=lambda x: getattr(x, 'pk', None),
                        primary_queryset: str = None,
                        ):
    """
    自定义分页处理函数
    """
    if isinstance(queryset, Dict):
        if primary_queryset not in queryset:
            raise ValueError('Missing primary_queryset')
        
        querysets_to_process = [queryset[primary_queryset]]
    else:
        querysets_to_process = [queryset]
        
    for qs in querysets_to_process:
        last_processed_id = None
        processed_count = 0
        
        while processed_count < qs.count():
            if last_processed_id is not None:
                qs_with_condition = qs.filter(**{f'{primary_key}__gt': last_processed_id})
            else:
                qs_with_condition = qs
            
            sql, params = convert_queryset_to_sql(qs_with_condition, select_string, field_replace_dict, with_table_name)
            
            total_count_result = execute_query(total_count_sql, params)
            total_count = total_count_result['data'][0]['count']
            
            while processed_count < total_count:
                page_sql = wrap_in_pagination(sql, params, page_size, last_processed_id=last_processed_id)
                
                results = execute_query(page_sql, params)
                
                for row in results:
                    yield row
                
                processed_count += len(results)
                
                last_processed_id = get_primary_value(row)

def native_page_search_paginated(current_page: int, page_size: int, queryset: Union[QuerySet, Dict], handler: Callable[[int, List, Optional[int]], Iterator[Any]]):
    paginator = Paginator(queryset, per_page=page_size)
        
    try:
        response_data = paginator.page(current_page).object_list
        meta = {'total_pages': paginator.num_pages, 'current_page_number': current_page}
        
        result_rows = []
        for row in handler(page_size, response_data, current_page):
            result_rows.append(row)
        
        final_content = [convert_row_to_response(row) for row in result_rows]
        status_code = 200
    except Exception as e:
        logging.exception(e)
        result_rows = [{'error': str(e)}]
        meta = {'total_pages': 0, 'current_page_number': 0}

        status_code = 500
      
    response_payload = {"meta": meta, "content": final_content }
    return Response(data=response_payload)


def get_field_replace_dict(queryset: QuerySet):
    """
    获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx"
    """


def extract_last_value(item: Tuple[List[Tuple[str, Any]]]) -> tuple:
     ...

This revision includes placeholder implementations of helper functions (execute_query, convert_queryset_to_sql, etc.) and adds basic exception handling within the native_page_handler. You'll need to fill in the actual implementations based on your project's needs.

Expand Down
Loading