Skip to content

Commit 78b1b32

Browse files
committed
feat: MelleaSession.register
1 parent 6cd5249 commit 78b1b32

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

mellea/stdlib/session.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from __future__ import annotations
44

55
import contextvars
6+
import functools
7+
import inspect
8+
from collections.abc import Callable
69
from copy import copy
710
from typing import Any, Literal, overload
811

@@ -236,6 +239,43 @@ def cleanup(self) -> None:
236239
if hasattr(self.backend, "close"):
237240
self.backend.close() # type: ignore
238241

242+
@classmethod
243+
def register(cls, fn: Callable, set_context: bool = True):
244+
"""Registers fn as a new method to MelleaSession.
245+
246+
The function fn must accept `backend` and `context` arguments.
247+
"""
248+
249+
def postprocess(self, r):
250+
if set_context:
251+
if isinstance(r, SamplingResult):
252+
self.ctx = r.result_ctx
253+
return r
254+
else:
255+
result, context = r
256+
self.ctx = context
257+
return result
258+
else:
259+
return r
260+
261+
if inspect.iscoroutinefunction(fn):
262+
263+
@functools.wraps(fn)
264+
async def wrapper(self, *args, **kwargs):
265+
return postprocess(
266+
self,
267+
await fn(backend=self.backend, context=self.ctx, *args, **kwargs),
268+
)
269+
else:
270+
271+
@functools.wraps(fn)
272+
def wrapper(self, *args, **kwargs):
273+
return postprocess(
274+
self, fn(backend=self.backend, context=self.ctx, *args, **kwargs)
275+
)
276+
277+
setattr(cls, fn.__name__, wrapper)
278+
239279
@overload
240280
def act(
241281
self,

0 commit comments

Comments
 (0)