Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 10 additions & 76 deletions agentops/llms/providers/ai21.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
class AI21Provider(BaseProvider):
original_create = None
original_create_async = None
original_answer = None
original_answer_async = None

def __init__(self, client):
super().__init__(client)
Expand All @@ -28,11 +26,8 @@
from ai21.stream.stream import Stream
from ai21.stream.async_stream import AsyncStream
from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk
from ai21.models.chat.chat_completion_response import ChatCompletionResponse
from ai21.models.responses.answer_response import AnswerResponse

llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs)

if session is not None:
llm_event.session_id = session.session_id
Expand Down Expand Up @@ -108,27 +103,15 @@

# Handle object responses
try:
if isinstance(response, ChatCompletionResponse):
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
llm_event.prompt = [message.model_dump() for message in kwargs["messages"]]
llm_event.prompt_tokens = response.usage.prompt_tokens
llm_event.completion = response.choices[0].message.model_dump()
llm_event.completion_tokens = response.usage.completion_tokens
llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, llm_event)

elif isinstance(response, AnswerResponse):
action_event.returns = response
action_event.agent_id = check_call_stack_for_agent_id()
action_event.action_type = "Contextual Answers"
action_event.logs = [
{"context": kwargs["context"], "question": kwargs["question"]},
response.model_dump() if response.model_dump() else None,
]
action_event.end_timestamp = get_ISO_time()
self._safe_record(session, action_event)
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
llm_event.prompt = [message.model_dump() for message in kwargs["messages"]]
llm_event.prompt_tokens = response.usage.prompt_tokens
llm_event.completion = response.choices[0].message.model_dump()
llm_event.completion_tokens = response.usage.completion_tokens
llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, llm_event)

Check warning on line 114 in agentops/llms/providers/ai21.py

View check run for this annotation

Codecov / codecov/patch

agentops/llms/providers/ai21.py#L106-L114

Added lines #L106 - L114 were not covered by tests

except Exception as e:
self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
Expand All @@ -145,8 +128,6 @@
def override(self):
self._override_completion()
self._override_completion_async()
self._override_answer()
self._override_answer_async()

def _override_completion(self):
from ai21.clients.studio.resources.chat import ChatCompletions
Expand Down Expand Up @@ -184,59 +165,12 @@
# Override the original method with the patched one
AsyncChatCompletions.create = patched_function

def _override_answer(self):
from ai21.clients.studio.resources.studio_answer import StudioAnswer

global original_answer
original_answer = StudioAnswer.create

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = original_answer(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

StudioAnswer.create = patched_function

def _override_answer_async(self):
from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer

global original_answer_async
original_answer_async = AsyncStudioAnswer.create

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = await original_answer_async(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

AsyncStudioAnswer.create = patched_function

def undo_override(self):
if (
self.original_create is not None
and self.original_create_async is not None
and self.original_answer is not None
and self.original_answer_async is not None
):
if self.original_create is not None and self.original_create_async is not None:

Check warning on line 169 in agentops/llms/providers/ai21.py

View check run for this annotation

Codecov / codecov/patch

agentops/llms/providers/ai21.py#L169

Added line #L169 was not covered by tests
from ai21.clients.studio.resources.chat import (
ChatCompletions,
AsyncChatCompletions,
)
from ai21.clients.studio.resources.studio_answer import (
StudioAnswer,
AsyncStudioAnswer,
)

ChatCompletions.create = self.original_create
AsyncChatCompletions.create = self.original_create_async
StudioAnswer.create = self.original_answer
AsyncStudioAnswer.create = self.original_answer_async
Loading