diff --git a/.fernignore b/.fernignore index b3da066..ce1ae14 100644 --- a/.fernignore +++ b/.fernignore @@ -4,4 +4,5 @@ README.md src/multion/client.py -src/multion/sessions/wrapped_client.py \ No newline at end of file +src/multion/sessions/wrapped_client.py +src/multion/wrappers.py \ No newline at end of file diff --git a/src/multion/client.py b/src/multion/client.py index 2713d51..ba65531 100644 --- a/src/multion/client.py +++ b/src/multion/client.py @@ -8,6 +8,8 @@ import agentops import os +from .wrappers import wraps_function + # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) @@ -26,8 +28,6 @@ class MultiOn(BaseMultiOn): - api_key: typing.Optional[str]. - - agentops_api_key: typing.Optional[str]. - - timeout: typing.Optional[float]. The timeout to be used, in seconds, for requests by default the timeout is 60 seconds, unless a custom httpx client is used, in which case a default is not set. - follow_redirects: typing.Optional[bool]. Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in. @@ -38,12 +38,11 @@ class MultiOn(BaseMultiOn): client = MultiOn( api_key="YOUR_API_KEY", - agentops_api_key="YOUR_AGENTOPS_API_KEY", ) """ - def __init__(self, *args, **kwargs): - agentops_api_key = kwargs.pop("agentops_api_key", None) + @wraps_function(BaseMultiOn.__init__) + def __init__(self, *args, agentops_api_key: typing.Optional[str] = os.getenv("AGENTOPS_API_KEY"), **kwargs): super().__init__(*args, **kwargs) self.sessions = WrappedSessionsClient(client_wrapper=self._client_wrapper) if agentops_api_key is not None: @@ -54,6 +53,7 @@ def __init__(self, *args, **kwargs): ) @agentops.record_function(event_name="browse") + @wraps_function(BaseMultiOn.browse) def browse(self, *args, **kwargs): agentops.start_session(tags=["multion-sdk"]) return super().browse(*args, **kwargs) @@ -88,6 +88,7 @@ class AsyncMultiOn(AsyncBaseMultiOn): ) """ + @wraps_function(AsyncBaseMultiOn.__init__) def __init__(self, *args, **kwargs): agentops_api_key = kwargs.pop("agentops_api_key", None) super().__init__(*args, **kwargs) @@ -100,6 +101,7 @@ def __init__(self, *args, **kwargs): ) @agentops.record_function(event_name="browse") + @wraps_function(AsyncBaseMultiOn.browse) async def browse(self, *args, **kwargs): agentops.start_session(tags=["multion-sdk"]) return super().browse(*args, **kwargs) diff --git a/src/multion/sessions/wrapped_client.py b/src/multion/sessions/wrapped_client.py index ffed003..83cd5c9 100644 --- a/src/multion/sessions/wrapped_client.py +++ b/src/multion/sessions/wrapped_client.py @@ -12,11 +12,14 @@ # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) +from ..wrappers import wraps_function class WrappedSessionsClient(SessionsClient): + @wraps_function(SessionsClient.__init__) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @wraps_function(SessionsClient.create) def create(self, *args, **kwargs) -> SessionCreated: agentops.start_session(tags=["multion-sdk"]) try: @@ -26,6 +29,7 @@ def create(self, *args, **kwargs) -> SessionCreated: agentops.record(error_event) raise e + @wraps_function(SessionsClient.step_stream) def step_stream(self, *args, **kwargs) -> typing.Iterator[SessionStepStreamChunk]: action_event = ActionEvent(action_type="step_stream", params=kwargs) action_event.returns = "" @@ -50,8 +54,9 @@ def generator(): yield chunk return generator() - + @agentops.record_function(event_name="step") + @wraps_function(SessionsClient.step) def step(self, *args, **kwargs) -> SessionStepSuccess: llm_event = LLMEvent() step_response = super().step(*args, **kwargs) @@ -59,21 +64,23 @@ def step(self, *args, **kwargs) -> SessionStepSuccess: agentops.record(llm_event) return step_response + @wraps_function(SessionsClient.close) def close(self, *args, **kwargs) -> SessionsCloseResponse: close_response = super().close(*args, **kwargs) agentops.end_session("Success") return close_response @agentops.record_function(event_name="retrieve") + @wraps_function(SessionsClient.retrieve) def retrieve(self, *args, **kwargs) -> RetrieveOutput: return super().retrieve(*args, **kwargs) - -# TODO: Test async class WrappedAsyncSessionsClient(AsyncSessionsClient): + @wraps_function(AsyncSessionsClient.__init__) async def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @wraps_function(AsyncSessionsClient.__init__) async def create(self, *args, **kwargs) -> SessionCreated: agentops.start_session(tags=["multion-sdk"]) try: @@ -84,6 +91,7 @@ async def create(self, *args, **kwargs) -> SessionCreated: raise e @agentops.record_function(event_name="step_stream") + @wraps_function(AsyncSessionsClient.step_stream) async def step_stream( self, *args, **kwargs ) -> typing.Iterator[SessionStepStreamChunk]: @@ -112,6 +120,7 @@ def generator(): return generator() @agentops.record_function(event_name="step") + @wraps_function(AsyncSessionsClient.step) async def step(self, *args, **kwargs) -> SessionStepSuccess: llm_event = LLMEvent() step_response = super().step(*args, **kwargs) @@ -119,11 +128,13 @@ async def step(self, *args, **kwargs) -> SessionStepSuccess: agentops.record(llm_event) return step_response + @wraps_function(AsyncSessionsClient.close) async def close(self, *args, **kwargs) -> SessionsCloseResponse: close_response = super().close(*args, **kwargs) agentops.end_session("Success") return close_response + @wraps_function(AsyncSessionsClient.retrieve) @agentops.record_function(event_name="retrieve") async def retrieve(*args, **kwargs) -> RetrieveOutput: return super().retrieve(*args, **kwargs) diff --git a/src/multion/wrappers.py b/src/multion/wrappers.py new file mode 100644 index 0000000..67d35a3 --- /dev/null +++ b/src/multion/wrappers.py @@ -0,0 +1,20 @@ +import typing +from typing import Callable +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +T = typing.TypeVar("T") + + +def wraps_function( + _fun: Callable[P, T] +) -> Callable[[Callable[Concatenate[Callable[P, T], P], T]], Callable[P, T]]: + def decorator( + wrapped_fun: Callable[Concatenate[Callable[P, T], P], T] + ) -> Callable[P, T]: + def decorated(self, *args: P.args, **kwargs: P.kwargs) -> T: + return wrapped_fun(self, *args, **kwargs) + + return decorated + + return decorator