|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import contextvars |
| 6 | +import functools |
| 7 | +import inspect |
| 8 | +from collections.abc import Callable |
6 | 9 | from copy import copy |
7 | 10 | from typing import Any, Literal, overload |
8 | 11 |
|
@@ -236,6 +239,43 @@ def cleanup(self) -> None: |
236 | 239 | if hasattr(self.backend, "close"): |
237 | 240 | self.backend.close() # type: ignore |
238 | 241 |
|
| 242 | + @classmethod |
| 243 | + def register(cls, fn: Callable, set_context: bool = True): |
| 244 | + """Registers fn as a new method to MelleaSession. |
| 245 | +
|
| 246 | + The function fn must accept `backend` and `context` arguments. |
| 247 | + """ |
| 248 | + |
| 249 | + def postprocess(self, r): |
| 250 | + if set_context: |
| 251 | + if isinstance(r, SamplingResult): |
| 252 | + self.ctx = r.result_ctx |
| 253 | + return r |
| 254 | + else: |
| 255 | + result, context = r |
| 256 | + self.ctx = context |
| 257 | + return result |
| 258 | + else: |
| 259 | + return r |
| 260 | + |
| 261 | + if inspect.iscoroutinefunction(fn): |
| 262 | + |
| 263 | + @functools.wraps(fn) |
| 264 | + async def wrapper(self, *args, **kwargs): |
| 265 | + return postprocess( |
| 266 | + self, |
| 267 | + await fn(backend=self.backend, context=self.ctx, *args, **kwargs), |
| 268 | + ) |
| 269 | + else: |
| 270 | + |
| 271 | + @functools.wraps(fn) |
| 272 | + def wrapper(self, *args, **kwargs): |
| 273 | + return postprocess( |
| 274 | + self, fn(backend=self.backend, context=self.ctx, *args, **kwargs) |
| 275 | + ) |
| 276 | + |
| 277 | + setattr(cls, fn.__name__, wrapper) |
| 278 | + |
239 | 279 | @overload |
240 | 280 | def act( |
241 | 281 | self, |
|
0 commit comments