Skip to content

Commit 662cfcc

Browse files
authored
feat: MelleaSession.register for functional interface and MelleaSession.powerup for dynamic mixin (register all methods in a class) (#224)
* refactor: renamed mellea.stdlib.funcs -> mellea.stdlib.functional * feat: MelleaSession.powerup * test: testing powerup
1 parent 9d12458 commit 662cfcc

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed
File renamed without changes.

mellea/stdlib/sampling/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import tqdm
77

8+
import mellea.stdlib.functional as mfuncs
89
from mellea.backends import Backend, BaseModelSubclass
910
from mellea.helpers.fancy_logger import FancyLogger
10-
from mellea.stdlib import funcs as mfuncs
1111
from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk
1212
from mellea.stdlib.chat import Message
1313
from mellea.stdlib.instruction import Instruction

mellea/stdlib/sampling/best_of_n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import tqdm
66

7+
import mellea.stdlib.functional as mfuncs
78
from mellea.backends import Backend, BaseModelSubclass
89
from mellea.helpers.async_helpers import wait_for_all_mots
910
from mellea.helpers.fancy_logger import FancyLogger
10-
from mellea.stdlib import funcs as mfuncs
1111
from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk
1212
from mellea.stdlib.instruction import Instruction
1313
from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult

mellea/stdlib/session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from __future__ import annotations
44

55
import contextvars
6+
import inspect
67
from copy import copy
78
from typing import Any, Literal, overload
89

910
from PIL import Image as PILImage
1011

11-
import mellea.stdlib.funcs as mfuncs
12+
import mellea.stdlib.functional as mfuncs
1213
from mellea.backends import Backend, BaseModelSubclass
1314
from mellea.backends.model_ids import (
1415
IBM_GRANITE_3_3_8B,
@@ -804,6 +805,12 @@ async def atransform(
804805
self.ctx = context
805806
return result
806807

808+
@classmethod
809+
def powerup(cls, powerup_cls: type):
810+
"""Appends methods in a class object `powerup_cls` to MelleaSession."""
811+
for name, fn in inspect.getmembers(powerup_cls, predicate=inspect.isfunction):
812+
setattr(cls, name, fn)
813+
807814
# ###############################
808815
# Convenience functions
809816
# ###############################

test/stdlib_basics/test_funcs.py renamed to test/stdlib_basics/test_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mellea.backends.types import ModelOption
44
from mellea.stdlib.base import ModelOutputThunk
55
from mellea.stdlib.chat import Message
6-
from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct
6+
from mellea.stdlib.functional import instruct, aact, avalidate, ainstruct
77
from mellea.stdlib.requirement import req
88
from mellea.stdlib.session import start_session
99

test/stdlib_basics/test_session.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mellea.backends.types import ModelOption
88
from mellea.stdlib.base import ChatContext, ModelOutputThunk
99
from mellea.stdlib.chat import Message
10-
from mellea.stdlib.session import start_session
10+
from mellea.stdlib.session import start_session, MelleaSession
1111

1212

1313
# We edit the context type in the async tests below. Don't change the scope here.
@@ -134,5 +134,17 @@ def test_session_copy_with_context_ops(m_session):
134134
assert m2.ctx.previous_node.previous_node is m_session.ctx
135135

136136

137+
class TestPowerup:
138+
def hello(m:MelleaSession):
139+
return "hello"
140+
141+
142+
def test_powerup(m_session):
143+
144+
MelleaSession.powerup(TestPowerup)
145+
146+
assert "hello" == m_session.hello()
147+
148+
137149
if __name__ == "__main__":
138150
pytest.main([__file__])

0 commit comments

Comments
 (0)