Skip to content

Commit 57cca97

Browse files
committed
Refactor and bug fixes.
Deletes the stdlib.span package and moves simplecomponent into base. Fixes a big in call to gather (should be *list not list)
1 parent 8cdcae5 commit 57cca97

File tree

8 files changed

+60
-79
lines changed

8 files changed

+60
-79
lines changed

docs/examples/melp/lazy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
2-
from mellea.stdlib.span import Span, SimpleComponent
3-
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk
2+
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk, SimpleComponent
43
from mellea.backends import Backend
54
from mellea.backends.ollama import OllamaModelBackend
65

docs/examples/melp/lazy_fib.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
2-
from mellea.stdlib.span import Span, SimpleComponent
3-
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk
2+
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk, SimpleComponent
43
from mellea.stdlib.requirement import Requirement
54
from mellea.backends import Backend
65
from mellea.backends.ollama import OllamaModelBackend

docs/examples/melp/lazy_fib_sample.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
2-
from mellea.stdlib.span import Span, SimpleComponent
3-
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk
2+
from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk, SimpleComponent
43
from mellea.stdlib.requirement import Requirement
54
from mellea.backends import Backend
65
from mellea.backends.ollama import OllamaModelBackend

docs/examples/melp/states.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import mellea
2-
from mellea.stdlib.base import CBlock, Context, SimpleContext
3-
from mellea.stdlib.span import Span, SimpleComponent
1+
from mellea.stdlib.base import SimpleContext, Context, CBlock, SimpleComponent
42
from mellea.backends import Backend
53
from mellea.backends.ollama import OllamaModelBackend
64
import asyncio

mellea/backends/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputT
9898
return []
9999
case Component():
100100
parts_walk = [generate_walk(p) for p in c.parts()]
101-
return itertools.chain.from_iterable(parts_walk) # aka flatten
101+
return list(itertools.chain.from_iterable(parts_walk)) # aka flatten

mellea/backends/ollama.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,9 @@ async def generate_from_chat_context(
296296
RuntimeError: If not called from a thread with a running event loop.
297297
"""
298298
# Start by awaiting any necessary computation.
299-
_to_compute = generate_walk(action)
300-
await asyncio.gather([x.avalue() for x in _to_compute])
299+
_to_compute = list(generate_walk(action))
300+
coroutines = [x.avalue() for x in _to_compute]
301+
await asyncio.gather(*coroutines)
301302
FancyLogger.get_logger().info(
302303
f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots."
303304
)
@@ -419,7 +420,8 @@ async def generate_from_raw(
419420
_to_compute = []
420421
for act in actions:
421422
_to_compute.extend(generate_walk(act))
422-
await asyncio.gather([x.avalue() for x in _to_compute])
423+
coroutines = [x.avalue() for x in _to_compute]
424+
await asyncio.gather(*coroutines)
423425

424426
prompts = [self.formatter.print(action) for action in actions]
425427

mellea/stdlib/base.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,3 +656,53 @@ class ModelToolCall:
656656
def call_func(self) -> Any:
657657
"""A helper function for calling the function/tool represented by this object."""
658658
return self.func(**self.args)
659+
660+
661+
class SimpleComponent(Component):
662+
"""A Component that is make up of named spans."""
663+
664+
def __init__(self, **kwargs):
665+
"""Initialized a simple component of the constructor's kwargs."""
666+
for key in kwargs.keys():
667+
if type(kwargs[key]) is str:
668+
kwargs[key] = CBlock(value=kwargs[key])
669+
self._kwargs_type_check(kwargs)
670+
self._kwargs = kwargs
671+
672+
def parts(self):
673+
"""Returns the values of the kwargs."""
674+
return list(self._kwargs.values())
675+
676+
def _kwargs_type_check(self, kwargs):
677+
for key in kwargs.keys():
678+
value = kwargs[key]
679+
assert issubclass(type(value), Component) or issubclass(
680+
type(value), CBlock
681+
), f"Expected span but found {type(value)} of value: {value}"
682+
assert type(key) is str
683+
return True
684+
685+
@staticmethod
686+
def make_simple_string(kwargs):
687+
"""Uses <|key|>value</|key|> to represent a simple component."""
688+
return "\n".join(
689+
[f"<|{key}|>{value}</|{key}|>" for (key, value) in kwargs.items()]
690+
)
691+
692+
@staticmethod
693+
def make_json_string(kwargs):
694+
"""Uses json."""
695+
str_args = dict()
696+
for key in kwargs.keys():
697+
match kwargs[key]:
698+
case ModelOutputThunk() | CBlock():
699+
str_args[key] = kwargs[key].value
700+
case Component():
701+
str_args[key] = kwargs[key].format_for_llm()
702+
import json
703+
704+
return json.dumps(str_args)
705+
706+
def format_for_llm(self):
707+
"""Uses a string rep."""
708+
return SimpleComponent.make_json_string(self._kwargs)

mellea/stdlib/span/__init__.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)