
Commit caca2a1

committed
ollama generate walk.
TODO-nrf: we need to add generate walks to every generation call.
1 parent 0388379 commit caca2a1
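
The shape of the change, as a sketch (`generate_walk` and `avalue` are the names from the diff below; `action` stands in for whatever a call site is about to generate from): every generation call should first await the uncomputed ModelOutputThunks its action transitively contains, so the prompt can be rendered from concrete values.

    # Sketch of the pattern this commit adds to the Ollama backend,
    # and that the TODO asks to replicate in every generation call:
    for todo in generate_walk(action):
        await todo.avalue()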

File tree

7 files changed: +214 −19 lines changed

docs/examples/melp/lazy.py

Lines changed: 19 additions & 14 deletions

@@ -7,23 +7,28 @@
 backend = OllamaModelBackend("granite4:latest")
 
 
-async def main(backend: Backend, ctx: Context):
-    s1 = CBlock("What is 1+1? Respond with the number only.")
-    s1_out, _ = await backend.generate_from_context(action=s1, ctx=SimpleContext())
-
-    s2 = CBlock("What is 2+2? Respond with the number only.")
-    s2_out, _ = await backend.generate_from_context(action=s2, ctx=SimpleContext())
-
-    sc1 = SimpleComponent(
-        instruction="What is x+y? Respond with the number only", x=s1_out, y=s2_out
+async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
     )
+    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+    return mot
 
-    print(await s1_out.avalue())
-    print(await s2_out.avalue())
 
-    sc1_out, _ = await backend.generate_from_context(action=sc1, ctx=SimpleContext())
-
-    print(await sc1_out.avalue())
+async def main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(100):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i + 1}"))
+        else:
+            fibs.append(await fib(backend, ctx, fibs[i - 1], fibs[i - 2]))
+
+    for x in fibs:
+        match x:
+            case ModelOutputThunk():
+                print(await x.avalue())
+            case CBlock():
+                print(x.value)
 
 
 asyncio.run(main(backend, SimpleContext()))

docs/examples/melp/lazy_fib.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import asyncio
+from mellea.stdlib.span import Span, SimpleComponent
+from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk
+from mellea.stdlib.requirement import Requirement
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+from typing import Tuple
+
+backend = OllamaModelBackend("granite4:latest")
+
+
+async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
+    )
+    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+    return mot
+
+
+async def fib_main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(20):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i}"))
+        else:
+            mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2])
+            fibs.append(mot)
+
+    for x in fibs:
+        match x:
+            case ModelOutputThunk():
+                n = await x.avalue()
+                print(n)
+            case CBlock():
+                print(x.value)
+
+
+asyncio.run(fib_main(backend, SimpleContext()))
Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+import asyncio
+from mellea.stdlib.span import Span, SimpleComponent
+from mellea.stdlib.base import SimpleContext, Context, CBlock, ModelOutputThunk
+from mellea.stdlib.requirement import Requirement
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+from typing import Tuple
+
+backend = OllamaModelBackend("granite4:latest")
+
+
+async def _fib_sample(
+    backend: Backend, ctx: Context, x: CBlock, y: CBlock
+) -> ModelOutputThunk | None:
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
+    )
+    answer_mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+
+    # This is a fundamental thing: it means computation must occur.
+    # We need to be able to read this off at computation-graph construction time.
+    value = await answer_mot.avalue()
+
+    try:
+        int(value)
+        return answer_mot
+    except ValueError:
+        return None
+
+
+async def fib_sampling_version(
+    backend: Backend, ctx: Context, x: CBlock, y: CBlock
+) -> ModelOutputThunk | None:
+    for i in range(5):
+        sample = await _fib_sample(backend, ctx, x, y)
+        if sample is not None:
+            return sample
+        else:
+            continue
+    return None
+
+
+async def fib_sampling_version_main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(20):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i}"))
+        else:
+            mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2])
+            fibs.append(mot)
+
+    for x_i, x in enumerate(fibs):
+        match x:
+            case ModelOutputThunk():
+                n = await x.avalue()
+                print(n)
+            case CBlock():
+                print(x.value)
+
+
+asyncio.run(fib_sampling_version_main(backend, SimpleContext()))

docs/examples/melp/states.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import mellea
+from mellea.stdlib.base import CBlock, Context, SimpleContext
+from mellea.stdlib.span import Span, SimpleComponent
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    a_states = "Alaska,Arizona,Arkansas".split(",")
+    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"
+
+    a_state_pops = dict()
+    for state in a_states:
+        a_state_pops[state], _ = await backend.generate_from_context(
+            CBlock(f"What is the population of {state}? Respond with an integer only."),
+            SimpleContext(),
+        )
+    a_total_pop = SimpleComponent(
+        instruction=CBlock(
+            "What is the total population of these states? Respond with an integer only."
+        ),
+        **a_state_pops,
+    )
+    a_state_total, _ = await backend.generate_from_context(a_total_pop, SimpleContext())
+
+    m_state_pops = dict()
+    for state in m_states:
+        m_state_pops[state], _ = await backend.generate_from_context(
+            CBlock(f"What is the population of {state}? Respond with an integer only."),
+            SimpleContext(),
+        )
+    m_total_pop = SimpleComponent(
+        instruction=CBlock(
+            "What is the total population of these states? Respond with an integer only."
+        ),
+        **m_state_pops,
+    )
+    m_state_total, _ = await backend.generate_from_context(m_total_pop, SimpleContext())
+
+    print(await a_state_total.avalue())
+    print(await m_state_total.avalue())
+
+
+backend = OllamaModelBackend(model_id="granite4:latest")
+asyncio.run(main(backend, SimpleContext()))

mellea/backends/_utils.py

Lines changed: 20 additions & 1 deletion

@@ -1,13 +1,20 @@
 from __future__ import annotations
 
 import inspect
+import itertools
 from collections.abc import Callable
 from typing import Any, Literal
 
 from mellea.backends.formatter import Formatter
 from mellea.backends.tools import parse_tools
 from mellea.helpers.fancy_logger import FancyLogger
-from mellea.stdlib.base import CBlock, Component, Context, ModelToolCall
+from mellea.stdlib.base import (
+    CBlock,
+    Component,
+    Context,
+    ModelOutputThunk,
+    ModelToolCall,
+)
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 

@@ -80,3 +87,15 @@ def to_tool_calls(
     if len(model_tool_calls) > 0:
         return model_tool_calls
     return None
+
+
+def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]:
+    """Returns the generation walk ordering for a Span."""
+    match c:
+        case ModelOutputThunk() if not c.is_computed():
+            return [c]
+        case CBlock():
+            return []
+        case Component():
+            parts_walk = [generate_walk(p) for p in c.parts()]
+            return list(itertools.chain.from_iterable(parts_walk))  # aka flatten
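
A minimal usage sketch of `generate_walk`, assuming `SimpleComponent.parts()` returns the kwarg values as in this commit; `pending` stands for a hypothetical ModelOutputThunk whose `is_computed()` is False:

    # Already-computed blocks contribute nothing to the walk;
    # uncomputed thunks contribute themselves.
    sc = SimpleComponent(instruction=CBlock("What is x+y?"), x=CBlock("1"), y=pending)
    assert generate_walk(sc) == [pending]
    assert generate_walk(CBlock("42")) == []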

mellea/backends/ollama.py

Lines changed: 13 additions & 1 deletion

@@ -11,6 +11,7 @@
 
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
+from mellea.backends._utils import generate_walk
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
 from mellea.backends.model_ids import ModelIdentifier
 from mellea.backends.tools import (

@@ -294,6 +295,12 @@ async def generate_from_chat_context(
         Raises:
             RuntimeError: If not called from a thread with a running event loop.
         """
+        # Start by awaiting any necessary computation.
+        _computed = [await todo.avalue() for todo in generate_walk(action)]
+        FancyLogger.get_logger().info(
+            f"generate_from_chat_context awaited on {len(_computed)} uncomputed mots."
+        )
+
         model_opts = self._simplify_and_merge(model_options)
 
         linearized_context = ctx.view_for_generation()

@@ -408,9 +415,14 @@ async def generate_from_raw(
 
         model_opts = self._simplify_and_merge(model_options)
 
+        for act in actions:
+            for todo in generate_walk(act):
+                await todo.avalue()
+
+        prompts = [self.formatter.print(action) for action in actions]
+
         # Ollama doesn't support "batching". There's some ability for concurrency. Use that here.
         # See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests.
-        prompts = [self.formatter.print(action) for action in actions]
 
         # Run async so that we can make use of Ollama's concurrency.
         coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
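
Note the ordering in `generate_from_raw`: rendering prompts with `self.formatter.print` presumably requires each thunk's concrete value, so the prompt construction now happens only after the generate walks over all actions have been awaited.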

mellea/stdlib/span/__init__.py

Lines changed: 17 additions & 3 deletions

@@ -23,14 +23,14 @@ def __init__(self, **kwargs):
 
     def parts(self):
        """Returns the values of the kwargs."""
-        return self._kwargs.values()
+        return list(self._kwargs.values())
 
     def _kwargs_type_check(self, kwargs):
         for key in kwargs.keys():
             value = kwargs[key]
             assert issubclass(type(value), Component) or issubclass(
                 type(value), CBlock
-            ), f"Expected span but found {type(value)}"
+            ), f"Expected span but found {type(value)} of value: {value}"
             assert type(key) is str
         return True
 

@@ -41,9 +41,23 @@ def make_simple_string(kwargs):
             [f"<|{key}|>{value}</|{key}|>" for (key, value) in kwargs.items()]
         )
 
+    @staticmethod
+    def make_json_string(kwargs):
+        """Uses json."""
+        str_args = dict()
+        for key in kwargs.keys():
+            match kwargs[key]:
+                case ModelOutputThunk() | CBlock():
+                    str_args[key] = kwargs[key].value
+                case Component():
+                    str_args[key] = kwargs[key].format_for_llm()
+        import json
+
+        return json.dumps(str_args)
+
     def format_for_llm(self):
         """Uses a string rep."""
-        return SimpleComponent.make_simple_string(self._kwargs)
+        return SimpleComponent.make_json_string(self._kwargs)
         # """ Uses a simple tagging structure that needs to be changed in the future. """
         # return TemplateRepresentation(
         #     obj=self,
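
For illustration, a rough sketch of the serialization that `format_for_llm` now produces, assuming the kwargs are CBlocks whose `.value` holds the raw strings (for a ModelOutputThunk, `.value` is only meaningful once it has been computed, which is exactly what the generate walk added in this commit guarantees):

    sc = SimpleComponent(instruction=CBlock("What is x+y?"), x=CBlock("1"), y=CBlock("2"))
    sc.format_for_llm()
    # => '{"instruction": "What is x+y?", "x": "1", "y": "2"}'

The old `make_simple_string` representation wrapped each kwarg in `<|key|>...</|key|>` tags instead.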
