|
| 1 | +# Copyright (c) Saga Inc. |
| 2 | +# Distributed under the terms of the GNU Affero General Public License v3.0 License. |
| 3 | + |
| 4 | +from typing import List, Literal, Union |
| 5 | +from openai.types.chat import ChatCompletionMessageParam |
| 6 | +from mito_ai.completions.models import ScratchpadResultMetadata, MessageType, ResponseFormatInfo, AgentResponse |
| 7 | +from mito_ai.completions.prompt_builders.scratchpad_result_prompt import create_scratchpad_result_prompt |
| 8 | +from mito_ai.completions.providers import OpenAIProvider |
| 9 | +from mito_ai.completions.message_history import GlobalMessageHistory |
| 10 | +from mito_ai.completions.completion_handlers.completion_handler import CompletionHandler |
| 11 | +from mito_ai.completions.completion_handlers.utils import append_agent_system_message, create_ai_optimized_message |
| 12 | + |
| 13 | +__all__ = ["get_scratchpad_result_completion"] |
| 14 | + |
| 15 | +class ScratchpadResultHandler(CompletionHandler[ScratchpadResultMetadata]): |
| 16 | + """Handler for scratchpad result completions.""" |
| 17 | + |
| 18 | + @staticmethod |
| 19 | + async def get_completion( |
| 20 | + metadata: ScratchpadResultMetadata, |
| 21 | + provider: OpenAIProvider, |
| 22 | + message_history: GlobalMessageHistory, |
| 23 | + model: str |
| 24 | + ) -> str: |
| 25 | + """Get a scratchpad result completion from the AI provider.""" |
| 26 | + |
| 27 | + if metadata.index is not None: |
| 28 | + message_history.truncate_histories( |
| 29 | + thread_id=metadata.threadId, |
| 30 | + index=metadata.index |
| 31 | + ) |
| 32 | + |
| 33 | + # Add the system message if it doesn't already exist |
| 34 | + await append_agent_system_message(message_history, model, provider, metadata.threadId, True) |
| 35 | + |
| 36 | + # Create the prompt |
| 37 | + prompt = create_scratchpad_result_prompt(metadata) |
| 38 | + display_prompt = "" |
| 39 | + |
| 40 | + # Add the prompt to the message history |
| 41 | + new_ai_optimized_message = create_ai_optimized_message(prompt, None, None) |
| 42 | + new_display_optimized_message: ChatCompletionMessageParam = {"role": "user", "content": display_prompt} |
| 43 | + |
| 44 | + await message_history.append_message(new_ai_optimized_message, new_display_optimized_message, model, provider, metadata.threadId) |
| 45 | + |
| 46 | + # Get the completion |
| 47 | + completion = await provider.request_completions( |
| 48 | + messages=message_history.get_ai_optimized_history(metadata.threadId), |
| 49 | + model=model, |
| 50 | + response_format_info=ResponseFormatInfo( |
| 51 | + name='agent_response', |
| 52 | + format=AgentResponse |
| 53 | + ), |
| 54 | + message_type=MessageType.AGENT_SCRATCHPAD_RESULT, |
| 55 | + user_input="", |
| 56 | + thread_id=metadata.threadId |
| 57 | + ) |
| 58 | + |
| 59 | + ai_response_message: ChatCompletionMessageParam = {"role": "assistant", "content": completion} |
| 60 | + |
| 61 | + await message_history.append_message(ai_response_message, ai_response_message, model, provider, metadata.threadId) |
| 62 | + |
| 63 | + return completion |
| 64 | + |
| 65 | +# Use the static method directly |
| 66 | +get_scratchpad_result_completion = ScratchpadResultHandler.get_completion |
0 commit comments