Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions apps/application/flow/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"""
import asyncio
import json
import traceback
import queue
import threading
from typing import Iterator

from django.http import StreamingHttpResponse
Expand Down Expand Up @@ -242,6 +243,30 @@ def generate_tool_message_complete(name, input_content, output_content):
return tool_message_complete_template % (name, input_formatted, output_formatted)


# 全局单例事件循环
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()


def get_global_loop():
"""获取全局共享的事件循环"""
global _global_loop, _loop_thread

with _loop_lock:
if _global_loop is None:
_global_loop = asyncio.new_event_loop()

def run_forever():
asyncio.set_event_loop(_global_loop)
_global_loop.run_forever()

_loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop")
_loop_thread.start()

return _global_loop


async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
client = MultiServerMCPClient(json.loads(mcp_servers))
tools = await client.get_tools()
Expand Down Expand Up @@ -279,19 +304,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_


def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
loop = asyncio.new_event_loop()
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
while True:
try:
chunk = loop.run_until_complete(anext_async(async_gen))
yield chunk
except StopAsyncIteration:
break
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
finally:
loop.close()
"""使用全局事件循环,不创建新实例"""
result_queue = queue.Queue()
loop = get_global_loop() # 使用共享循环

async def _run():
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
async for chunk in async_gen:
result_queue.put(('data', chunk))
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
result_queue.put(('error', e))
finally:
result_queue.put(('done', None))

# 在全局循环中调度任务
asyncio.run_coroutine_threadsafe(_run(), loop)

while True:
msg_type, data = result_queue.get()
if msg_type == 'done':
break
if msg_type == 'error':
raise data
yield data


async def anext_async(agen):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code has several improvements and optimizations:

Improvements/Enhancements:

  1. Global Event Loop Management:

    • Global events loop management is implemented using a thread-safe approach. This avoids creating multiple event loops which can lead to resource leaks.
  2. Error Handling on Exit:

    • Proper error handling is added when the coroutine exits (finally block). The exception details are still logged.
  3. Concurrency Improvements:

    • Instead of creating a new event loop for each response, the function uses an existing global one managed by a separate thread. This reduces overhead associated with repeatedly spinning up and closing event loops.
  4. Queueing Results:

    • All asynchronous tasks that produce results now use a queue.Queue() instead of returning data synchronously. This keeps memory usage more efficient while maintaining high throughput.

Optimizations:

  1. Avoid Repeated Logging:

    • Ensure consistent logging format and possibly adjust logs based on their severity level.
  2. Thread Safety Enhancements:

    • Use context managers or explicit locking mechanisms where applicable across different threads to prevent race conditions.
  3. Memory Efficiency:

    • Minimize the time spent managing resources (like locks) during critical sections of execution. If possible, consider offloading computationally intensive parts to worker processes.
  4. Logging Contexts:

    • Consider adding additional contextual information to log messages such as request ID, user IP, etc., especially when logging errors for debugging.

By addressing these points, you improve both performance and maintainability of the codebase. Always ensure compatibility across the specified cutoff date and test thoroughly before deployment.

Expand Down
76 changes: 57 additions & 19 deletions apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,25 +234,62 @@ def run_block(self, language='zh'):
非流式响应
@return: 结果
"""
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)
return self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
try:
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)

res = self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
finally:
self._cleanup()
return res

def _cleanup(self):
"""清理所有对象引用"""
# 清理列表
self.future_list.clear()
self.field_list.clear()
self.global_field_list.clear()
self.chat_field_list.clear()
self.image_list.clear()
self.video_list.clear()
self.document_list.clear()
self.audio_list.clear()
self.other_list.clear()
if hasattr(self, 'node_context'):
self.node_context.clear()

# 清理字典
self.context.clear()
self.chat_context.clear()
self.form_data.clear()

# 清理对象引用
self.node_chunk_manage = None
self.work_flow_post_handler = None
self.flow = None
self.start_node = None
self.current_node = None
self.current_result = None
self.chat_record = None
self.base_to_response = None
self.params = None
self.lock = None

def run_stream(self, current_node, node_result_future, language='zh'):
"""
Expand Down Expand Up @@ -307,6 +344,7 @@ def await_result(self):
'',
[],
'', True, message_tokens, answer_tokens, {})
self._cleanup()

def run_chain_async(self, current_node, node_result_future, language='zh'):
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proposed code has several improvements:

  1. Exception Handling: Wrapped the main logic with a try-except block to ensure that any unexpected errors during execution are caught and cleaned up properly.

  2. Cleanup Method: Created a _cleanup method specifically for cleaning up resources at the end of operations like running a chain asynchronously. This ensures all object references are cleared, which can prevent memory leaks.

  3. Docstring Updates: Updated comments throughout the methods to be more informative and consistent in format.

  4. Comments Removal: Removed unnecessary redundant comments within functions as they do not enhance readability but may confuse readers.

These changes improve the robustness of the code by handling exceptions and reducing resource pollution due to stale references.

Suggested Cleanup Code

To further streamline the cleanup process, you could also consider defining an abstract base class (NodeContext) for managing context objects across different nodes, so each subclass handles its own specific clean-ups. Here's how it might look:

class NodeContext(metaclass=abc.ABCMeta):
    def clear(self): pass

# Example implementation using threading.local for thread-specific contexts
import threading

current_context = threading.local()

def setup_thread_specific_context():
    global current_context
    current_context.node_chunk_manage = None
    current_context.work_flow_post_handler = None
    current_context.flow = None
    current_context.start_node = None
    current_context.current_node = None
    current_context.current_result = None
    current_context.chat_record = None
    current_context.base_to_response = None
    current_context.params = None
    current_context.lock = None

setup_thread_specific_context()

This approach centralizes context management, making it easier to handle various stateful objects across threads or asynchronous tasks.

Final Cleaned-up Code Snippet:

@staticmethod
def start_task(flow_id: str) -> Task:
    flow = Flow.objects.filter(id=flow_id).first()
    if not flow:
        raise ValueError(f"Flow with id {flow_id} does not exist.")
    
    return cls(flow)

async def run_async(self, async_loop=None, language="zh"):
    try:
        start = time.time()
        self._start_time = start
        
        await self._set_start_info(async_loop, "run_async", False)
        
        # Additional steps...
        
        while True:
            await asyncio.sleep(0.1)
            if not self.running and all(not f.done() for f in getattr(self, 'future', [])):
                break
        
        result_text, info_dict = [], []
        
        # Handle intermediate messages (e.g., step-by-step outputs)
        
        # Generate final response
        resp = self._to_api_resp("final_response", "", "", {}, "")
        
        # Await results from background tasks
        await asyncio.gather(*[f.result() for f in getattr(self, 'future', [])])
        
        print(f'Total execution time took {time.time() - start:.2f}s')
        
        return resp
    
    finally:
        self._end_time = round(time.time(), 4)
        total_cost = int((self._execution_total / len(self.output_queue)) * 1000) if self.output_queue else 0
        
        logger.info(f"{task_type.upper()}, cost:{total_cost}ms")
        
        task_manager.save_task(self.task)
        setattr(self.task, '_id', self.id)
        self.save()
        
        self._clear_all_objects()
        

def _clear_all_objects(self):
    if hasattr(current_context, 'result'):
        del current_context.result

Note: The above cleaned-up code assumes the existence of relevant classes and modules such as Task, Flow, output_queue, among others. These need to be implemented or referenced correctly based on your actual system architecture.

Expand Down
61 changes: 58 additions & 3 deletions apps/application/serializers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
@date:2025/6/9 13:42
@desc:
"""
from datetime import datetime
from typing import List

from django.core.cache import cache
Expand Down Expand Up @@ -226,10 +225,66 @@ def append_chat_record(self, chat_record: ChatRecord):
chat_record.save()
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()

def to_dict(self):

return {
'chat_id': self.chat_id,
'chat_user_id': self.chat_user_id,
'chat_user_type': self.chat_user_type,
'knowledge_id_list': self.knowledge_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
'application_id': self.application_id,
'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list],
'debug': self.debug
}

def chat_record_to_map(self, chat_record):
return {'id': chat_record.id,
'chat_id': chat_record.chat_id,
'vote_status': chat_record.vote_status,
'problem_text': chat_record.problem_text,
'answer_text': chat_record.answer_text,
'answer_text_list': chat_record.answer_text_list,
'message_tokens': chat_record.message_tokens,
'answer_tokens': chat_record.answer_tokens,
'const': chat_record.const,
'details': chat_record.details,
'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
'run_time': chat_record.run_time,
'index': chat_record.index}

@staticmethod
def map_to_chat_record(chat_record_dict):
ChatRecord(id=chat_record_dict.get('id'),
chat_id=chat_record_dict.get('chat_id'),
vote_status=chat_record_dict.get('vote_status'),
problem_text=chat_record_dict.get('problem_text'),
answer_text=chat_record_dict.get('answer_text'),
answer_text_list=chat_record_dict.get('answer_text_list'),
message_tokens=chat_record_dict.get('message_tokens'),
answer_tokens=chat_record_dict.get('answer_tokens'),
const=chat_record_dict.get('const'),
details=chat_record_dict.get('details'),
improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
run_time=chat_record_dict.get('run_time'),
index=chat_record_dict.get('index'), )

def set_cache(self):
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
version=Cache_Version.CHAT_INFO.get_version(),
timeout=60 * 30)

@staticmethod
def map_to_chat_info(chat_info_dict):
return ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'),
chat_info_dict.get('exclude_document_id_list'),
chat_info_dict.get('application_id'),
[ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')])

@staticmethod
def get_cache(chat_id):
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version())
if chat_info_dict:
return ChatInfo.map_to_chat_info(chat_info_dict)
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks generally clean and follows Python conventions. Here are some minor points for consideration:

Potential Issues:

  1. String Formatting in @dataclass Decorator: The string format is not compatible with dataclasses.dataclass. Use 'YYYY-MM-DD HH:MM' instead of the full date format.

  2. Empty Lists:

    exclude_document_id_list=self.exclude_document_id_list

    Ensure that exclude_document_id_list is properly initialized before passing it to the constructor.

  3. Static Method Parameters:
    The static method map_to_chat_record should use keyword parameters (chat_record_dict) without default values, but you can add an optional parameter like this:

    @staticmethod
    def map_to_chat_record(chat_record_dict=None):
        if chat_record_dict is None:
            chat_record_dict = {}
        return ChatRecord(
            id=chat_record_dict.get('id'):
            chat_id=chat_record_dict.get('chat_id'),
            vote_status=chat_record_dict.get('vote_status'),
            problem_text=chat_record_dict.get('problem_text'),
            answer_text=chat_record_dict.get('answer_text'),
            answer_text_list=chat_record_dict.get('answer_text_list'),
            message_tokens=chat_record_dict.get('message_tokens'),
            answer_tokens=chat_record_dict.get('answer_tokens'),
            const=chat_record_dict.get('const'),
            details=chat_record_dict.get('details'),
            improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
            run_time=chat_record_dict.get('run_time'),
            index=chat_record_dict.get('index'))

Optimization Suggestions:

  1. Avoid Recursively Serializing Objects:
    In the case of nested objects (like chat_record_list), avoid serializing them recursively within each object by using a custom function. This can be done by modifying the serialization logic:

    def chat_record_to_dict(self, chat_record):
        return {
            'id': chat_record.id,
            # Serialize other fields here...
            'run_time': chat_record.run_time.strftime('%Y-%M-%d %H:%M:%S')  # Example format conversion
        }
    
    def serialize_data(self):
        data = self.to_dict()
        data['chat_record_list'] = [
            self.chat_record_to_dict(record) for record in self.chat_record_list
        ]
        return data
  2. Use Data Classes Instead of Regular Classes:
    Consider converting ChatRecord and ChatInfo into Pydantic models or FastAPI schemas for better type safety and additional features provided by these libraries.

  3. Consider Caching Strategy:
    If caching involves multiple layers (e.g., chat_info -> chat_records), ensure that the cache keys are appropriate and do not lead to conflicts between different caches.

By addressing these points, you can enhance both readability and performance of your code.

2 changes: 1 addition & 1 deletion apps/common/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def to_string(self):
value = json.dumps(self.to_dict())
authentication = encrypt(value)
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
return authentication

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions apps/common/constants/cache_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Cache_Version(Enum):
# 对话
CHAT = "CHAT", lambda key: key

CHAT_INFO = "CHAT_INFO", lambda key: key

CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key

# 应用API KEY
Expand Down
Loading