-
Notifications
You must be signed in to change notification settings - Fork 2.6k
perf: Memory optimization #4360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -234,25 +234,62 @@ def run_block(self, language='zh'): | |
| 非流式响应 | ||
| @return: 结果 | ||
| """ | ||
| self.run_chain_async(None, None, language) | ||
| while self.is_run(): | ||
| pass | ||
| details = self.get_runtime_details() | ||
| message_tokens = sum([row.get('message_tokens') for row in details.values() if | ||
| 'message_tokens' in row and row.get('message_tokens') is not None]) | ||
| answer_tokens = sum([row.get('answer_tokens') for row in details.values() if | ||
| 'answer_tokens' in row and row.get('answer_tokens') is not None]) | ||
| answer_text_list = self.get_answer_text_list() | ||
| answer_text = '\n\n'.join( | ||
| '\n\n'.join([a.get('content') for a in answer]) for answer in | ||
| answer_text_list) | ||
| answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) | ||
| self.work_flow_post_handler.handler(self) | ||
| return self.base_to_response.to_block_response(self.params['chat_id'], | ||
| self.params['chat_record_id'], answer_text, True | ||
| , message_tokens, answer_tokens, | ||
| _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| other_params={'answer_list': answer_list}) | ||
| try: | ||
| self.run_chain_async(None, None, language) | ||
| while self.is_run(): | ||
| pass | ||
| details = self.get_runtime_details() | ||
| message_tokens = sum([row.get('message_tokens') for row in details.values() if | ||
| 'message_tokens' in row and row.get('message_tokens') is not None]) | ||
| answer_tokens = sum([row.get('answer_tokens') for row in details.values() if | ||
| 'answer_tokens' in row and row.get('answer_tokens') is not None]) | ||
| answer_text_list = self.get_answer_text_list() | ||
| answer_text = '\n\n'.join( | ||
| '\n\n'.join([a.get('content') for a in answer]) for answer in | ||
| answer_text_list) | ||
| answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) | ||
| self.work_flow_post_handler.handler(self) | ||
|
|
||
| res = self.base_to_response.to_block_response(self.params['chat_id'], | ||
| self.params['chat_record_id'], answer_text, True | ||
| , message_tokens, answer_tokens, | ||
| _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
| other_params={'answer_list': answer_list}) | ||
| finally: | ||
| self._cleanup() | ||
| return res | ||
|
|
||
| def _cleanup(self): | ||
| """清理所有对象引用""" | ||
| # 清理列表 | ||
| self.future_list.clear() | ||
| self.field_list.clear() | ||
| self.global_field_list.clear() | ||
| self.chat_field_list.clear() | ||
| self.image_list.clear() | ||
| self.video_list.clear() | ||
| self.document_list.clear() | ||
| self.audio_list.clear() | ||
| self.other_list.clear() | ||
| if hasattr(self, 'node_context'): | ||
| self.node_context.clear() | ||
|
|
||
| # 清理字典 | ||
| self.context.clear() | ||
| self.chat_context.clear() | ||
| self.form_data.clear() | ||
|
|
||
| # 清理对象引用 | ||
| self.node_chunk_manage = None | ||
| self.work_flow_post_handler = None | ||
| self.flow = None | ||
| self.start_node = None | ||
| self.current_node = None | ||
| self.current_result = None | ||
| self.chat_record = None | ||
| self.base_to_response = None | ||
| self.params = None | ||
| self.lock = None | ||
|
|
||
| def run_stream(self, current_node, node_result_future, language='zh'): | ||
| """ | ||
|
|
@@ -307,6 +344,7 @@ def await_result(self): | |
| '', | ||
| [], | ||
| '', True, message_tokens, answer_tokens, {}) | ||
| self._cleanup() | ||
|
|
||
| def run_chain_async(self, current_node, node_result_future, language='zh'): | ||
| future = executor.submit(self.run_chain_manage, current_node, node_result_future, language) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The proposed code has several improvements:
These changes improve the robustness of the code by handling exceptions and reducing resource pollution due to stale references.

Suggested Cleanup Code

To further streamline the cleanup process, you could also consider defining an abstract base class:

class NodeContext(metaclass=abc.ABCMeta):
def clear(self): pass
# Example implementation using threading.local for thread-specific contexts
import threading
current_context = threading.local()
def setup_thread_specific_context():
global current_context
current_context.node_chunk_manage = None
current_context.work_flow_post_handler = None
current_context.flow = None
current_context.start_node = None
current_context.current_node = None
current_context.current_result = None
current_context.chat_record = None
current_context.base_to_response = None
current_context.params = None
current_context.lock = None
setup_thread_specific_context()

This approach centralizes context management, making it easier to handle various stateful objects across threads or asynchronous tasks.

Final Cleaned-up Code Snippet:

@staticmethod
def start_task(flow_id: str) -> Task:
flow = Flow.objects.filter(id=flow_id).first()
if not flow:
raise ValueError(f"Flow with id {flow_id} does not exist.")
return cls(flow)
async def run_async(self, async_loop=None, language="zh"):
try:
start = time.time()
self._start_time = start
await self._set_start_info(async_loop, "run_async", False)
# Additional steps...
while True:
await asyncio.sleep(0.1)
if not self.running and all(not f.done() for f in getattr(self, 'future', [])):
break
result_text, info_dict = [], []
# Handle intermediate messages (e.g., step-by-step outputs)
# Generate final response
resp = self._to_api_resp("final_response", "", "", {}, "")
# Await results from background tasks
await asyncio.gather(*[f.result() for f in getattr(self, 'future', [])])
print(f'Total execution time took {time.time() - start:.2f}s')
return resp
finally:
self._end_time = round(time.time(), 4)
total_cost = int((self._execution_total / len(self.output_queue)) * 1000) if self.output_queue else 0
logger.info(f"{task_type.upper()}, cost:{total_cost}ms")
task_manager.save_task(self.task)
setattr(self.task, '_id', self.id)
self.save()
self._clear_all_objects()
def _clear_all_objects(self):
if hasattr(current_context, 'result'):
        del current_context.result

Note: The above cleaned-up code assumes the existence of relevant classes and modules such as |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,6 @@ | |
| @date:2025/6/9 13:42 | ||
| @desc: | ||
| """ | ||
| from datetime import datetime | ||
| from typing import List | ||
|
|
||
| from django.core.cache import cache | ||
|
|
@@ -226,10 +225,66 @@ def append_chat_record(self, chat_record: ChatRecord): | |
| chat_record.save() | ||
| ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat() | ||
|
|
||
| def to_dict(self): | ||
|
|
||
| return { | ||
| 'chat_id': self.chat_id, | ||
| 'chat_user_id': self.chat_user_id, | ||
| 'chat_user_type': self.chat_user_type, | ||
| 'knowledge_id_list': self.knowledge_id_list, | ||
| 'exclude_document_id_list': self.exclude_document_id_list, | ||
| 'application_id': self.application_id, | ||
| 'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list], | ||
| 'debug': self.debug | ||
| } | ||
|
|
||
| def chat_record_to_map(self, chat_record): | ||
| return {'id': chat_record.id, | ||
| 'chat_id': chat_record.chat_id, | ||
| 'vote_status': chat_record.vote_status, | ||
| 'problem_text': chat_record.problem_text, | ||
| 'answer_text': chat_record.answer_text, | ||
| 'answer_text_list': chat_record.answer_text_list, | ||
| 'message_tokens': chat_record.message_tokens, | ||
| 'answer_tokens': chat_record.answer_tokens, | ||
| 'const': chat_record.const, | ||
| 'details': chat_record.details, | ||
| 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list, | ||
| 'run_time': chat_record.run_time, | ||
| 'index': chat_record.index} | ||
|
|
||
| @staticmethod | ||
| def map_to_chat_record(chat_record_dict): | ||
| ChatRecord(id=chat_record_dict.get('id'), | ||
| chat_id=chat_record_dict.get('chat_id'), | ||
| vote_status=chat_record_dict.get('vote_status'), | ||
| problem_text=chat_record_dict.get('problem_text'), | ||
| answer_text=chat_record_dict.get('answer_text'), | ||
| answer_text_list=chat_record_dict.get('answer_text_list'), | ||
| message_tokens=chat_record_dict.get('message_tokens'), | ||
| answer_tokens=chat_record_dict.get('answer_tokens'), | ||
| const=chat_record_dict.get('const'), | ||
| details=chat_record_dict.get('details'), | ||
| improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'), | ||
| run_time=chat_record_dict.get('run_time'), | ||
| index=chat_record_dict.get('index'), ) | ||
|
|
||
| def set_cache(self): | ||
| cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), | ||
| cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(), | ||
| version=Cache_Version.CHAT_INFO.get_version(), | ||
| timeout=60 * 30) | ||
|
|
||
| @staticmethod | ||
| def map_to_chat_info(chat_info_dict): | ||
| return ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'), | ||
| chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'), | ||
| chat_info_dict.get('exclude_document_id_list'), | ||
| chat_info_dict.get('application_id'), | ||
| [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')]) | ||
|
|
||
| @staticmethod | ||
| def get_cache(chat_id): | ||
| return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) | ||
| chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version()) | ||
| if chat_info_dict: | ||
| return ChatInfo.map_to_chat_info(chat_info_dict) | ||
| return None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code looks generally clean and follows Python conventions. Here are some minor points for consideration: Potential Issues:
Optimization Suggestions:
By addressing these points, you can enhance both readability and performance of your code. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code has several improvements and optimizations:
Improvements/Enhancements:
Global Event Loop Management:
Error Handling on Exit:
`finally` block). The exception details are still logged.Concurrency Improvements:
Queueing Results:
`queue.Queue()` instead of returning data synchronously. This keeps memory usage more efficient while maintaining high throughput.Optimizations:
Avoid Repeated Logging:
Thread Safety Enhancements:
Memory Efficiency:
Logging Contexts:
By addressing these points, you improve both performance and maintainability of the codebase. Always ensure compatibility across the specified cutoff date and test thoroughly before deployment.