|
20 | 20 | from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage |
21 | 21 | from core.model_runtime.entities.message_entities import ( |
22 | 22 | AssistantPromptMessage, |
| 23 | + PromptMessageContent, |
23 | 24 | PromptMessageRole, |
24 | 25 | SystemPromptMessage, |
25 | 26 | UserPromptMessage, |
@@ -828,14 +829,14 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: |
828 | 829 | } |
829 | 830 |
|
830 | 831 |
|
831 | | -def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): |
| 832 | +def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): |
832 | 833 | match role: |
833 | 834 | case PromptMessageRole.USER: |
834 | | - return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) |
| 835 | + return UserPromptMessage(content=contents) |
835 | 836 | case PromptMessageRole.ASSISTANT: |
836 | | - return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) |
| 837 | + return AssistantPromptMessage(content=contents) |
837 | 838 | case PromptMessageRole.SYSTEM: |
838 | | - return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) |
| 839 | + return SystemPromptMessage(content=contents) |
839 | 840 | raise NotImplementedError(f"Role {role} is not supported") |
840 | 841 |
|
841 | 842 |
|
@@ -877,7 +878,9 @@ def _handle_list_messages( |
877 | 878 | jinjia2_variables=jinja2_variables, |
878 | 879 | variable_pool=variable_pool, |
879 | 880 | ) |
880 | | - prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) |
| 881 | + prompt_message = _combine_message_content_with_role( |
| 882 | + contents=[TextPromptMessageContent(data=result_text)], role=message.role |
| 883 | + ) |
881 | 884 | prompt_messages.append(prompt_message) |
882 | 885 | else: |
883 | 886 | # Get segment group from basic message |
@@ -908,12 +911,14 @@ def _handle_list_messages( |
908 | 911 | # Create message with text from all segments |
909 | 912 | plain_text = segment_group.text |
910 | 913 | if plain_text: |
911 | | - prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) |
| 914 | + prompt_message = _combine_message_content_with_role( |
| 915 | + contents=[TextPromptMessageContent(data=plain_text)], role=message.role |
| 916 | + ) |
912 | 917 | prompt_messages.append(prompt_message) |
913 | 918 |
|
914 | 919 | if file_contents: |
915 | 920 | # Create message with image contents |
916 | | - prompt_message = UserPromptMessage(content=file_contents) |
| 921 | + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) |
917 | 922 | prompt_messages.append(prompt_message) |
918 | 923 |
|
919 | 924 | return prompt_messages |
@@ -1018,6 +1023,8 @@ def _handle_completion_template( |
1018 | 1023 | else: |
1019 | 1024 | template_text = template.text |
1020 | 1025 | result_text = variable_pool.convert_template(template_text).text |
1021 | | - prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) |
| 1026 | + prompt_message = _combine_message_content_with_role( |
| 1027 | + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER |
| 1028 | + ) |
1022 | 1029 | prompt_messages.append(prompt_message) |
1023 | 1030 | return prompt_messages |
0 commit comments