
Commit 4ae6d7c

feat: new context, new sampling (#166)

* non-breaking: Context -> LegacyContext
* non-breaking: *Context -> Legacy*Context
* Context, SimpleContext, ChatContext (née LinearContext) without log mechanics
* expand -> add
* remove RootContext in favor of _is_root attribute
* actions_for_available_tools is back
* all backends adjusted to new context structure; session modifications left open (unsafe commit)
* feat: add top-level funcs and start sample changes
* context adjustments and tests
* maybe breaking: removing chat, instruct, etc. from mellea/__init__.py
* untested session and sampling changes
* backends should add ACTION and MOT
* as_chat_history adapted
* fix: ollama client instantiation to prevent stale event loops in httpx sessions
* fix: mypy issues
* renaming: ctx.as_list(last_n_components...); ctx.render_for_generation() -> ctx.view_for_generation()
* name cleanup
* adding contexts to sampling
* handling sampling results and context; making rejection sampling the default
* adding warning
* name clarity
* remove unnecessary function
* updated tests: all converted to the new context; contextual_session test commented out with a TODO; session_ctx test removed because logging changed
* fix: watsonx errors
* fix: sampling and vision tests
* examples: react error to be looked into by @jal
* fixed react, thanks to @jal
* examples for alora (untested) and gen slots (tested)
* more examples updated
* remove old context
* fix: best-of-n and misc changes
* fix: improve runtime of best-of-n sampling
* fix: add tests
* fix: remove references to LinearContext

---------

Co-authored-by: jakelorocco <[email protected]>
1 parent 4ee56a9 commit 4ae6d7c

57 files changed (+1565 / -1453 lines)
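
The change that recurs across every diff below is the context migration: the mutable LinearContext (driven by ctx.insert(...)) is replaced by ChatContext, whose add() returns a new context that must be rebound to the session, and render_for_generation() is renamed to view_for_generation(). A minimal sketch of the new pattern, assembled from the examples in this commit (the message contents are illustrative only):

from mellea import start_session
from mellea.stdlib.base import ChatContext
from mellea.stdlib.chat import Message

# ChatContext.add() returns a new context instead of mutating in place,
# so the session's ctx must be rebound after each addition.
m = start_session(ctx=ChatContext())
m.ctx = m.ctx.add(Message(role="system", content="You are terse."))
m.ctx = m.ctx.add(Message(role="user", content="Say hello."))

# view_for_generation() replaces the old render_for_generation().
print(m.ctx.view_for_generation())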

docs/examples/aLora/101_example.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 import time
 
-from mellea import LinearContext, MelleaSession
+from mellea import MelleaSession
 from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora
 from mellea.backends.cache import SimpleLRUCache
 from mellea.backends.huggingface import LocalHFBackend
-from mellea.stdlib.base import GenerateLog
+from mellea.stdlib.base import ChatContext, GenerateLog
 from mellea.stdlib.requirement import ALoraRequirement, Requirement
 
 # Define a backend and add the constraint aLora
@@ -22,7 +22,7 @@
 backend.add_alora(custom_stembolt_failure_constraint)
 
 # Create M session
-m = MelleaSession(backend, ctx=LinearContext())
+m = MelleaSession(backend, ctx=ChatContext())
 
 # define a requirement
 failure_check = ALoraRequirement(

docs/examples/agents/react.py

Lines changed: 9 additions & 7 deletions
@@ -2,7 +2,7 @@
 import inspect
 import json
 from collections.abc import Callable
-from typing import Literal, Unpack
+from typing import Literal
 
 import pydantic
 from jinja2 import Template
@@ -13,6 +13,7 @@
 import mellea.stdlib
 import mellea.stdlib.base
 import mellea.stdlib.chat
+from mellea.stdlib.base import ChatContext
 
 react_system_template: Template = Template(
     """Answer the user's question as best you can.
@@ -83,7 +84,7 @@ def call_tool(self, tool: ReactTool, kwargs_json: str):
     def tool_name_schema(self):
         names = self.tool_names()
         fields = dict()
-        fields["tool"] = Literal[Unpack[names]]
+        fields["tool"] = Literal[*names]
         return pydantic.create_model("ToolSelectionSchema", **fields)
 
     def get_tool_from_schema(self, content: str):
@@ -103,7 +104,7 @@ def react(
     react_toolbox: ReactToolbox,
 ):
     assert m.ctx.is_chat_context, "ReACT requires a chat context."
-    test_ctx_lin = m.ctx.render_for_generation()
+    test_ctx_lin = m.ctx.view_for_generation()
     assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
         "ReACT expects a fresh context."
     )
@@ -114,8 +115,9 @@ def react(
     )
 
     # Add the system prompt and the goal to the chat history.
-    m.ctx.insert(mellea.stdlib.chat.Message(role="system", content=_sys_prompt))
-    m.ctx.insert(mellea.stdlib.chat.Message(role="user", content=f"{goal}"))
+    m.ctx = m.ctx.add(
+        mellea.stdlib.chat.Message(role="system", content=_sys_prompt)
+    ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}"))
 
     # The main ReACT loop as a dynamic program:
     # ( ?(not done) ;
@@ -156,7 +158,7 @@ def react(
 
         print("### Observation")
         tool_output = react_toolbox.call_tool(selected_tool, act_args.content)
-        m.ctx.insert(mellea.stdlib.chat.Message(role="tool", content=tool_output))
+        m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output))
         print(tool_output)
 
         print("### Done Check")
@@ -178,7 +180,7 @@ def react(
 
 
 if __name__ == "__main__":
-    m = mellea.start_session(ctx=mellea.stdlib.base.LinearContext())
+    m = mellea.start_session(ctx=ChatContext())
 
     def zip_lookup_tool_fn(city: str):
         """Returns the ZIP code for the `city`."""

docs/examples/agents/react_instruct.py

Lines changed: 9 additions & 7 deletions
@@ -2,7 +2,7 @@
 import inspect
 import json
 from collections.abc import Callable
-from typing import Literal, Unpack
+from typing import Literal
 
 import pydantic
 from jinja2 import Template
@@ -11,6 +11,7 @@
 import mellea.stdlib
 import mellea.stdlib.base
 import mellea.stdlib.chat
+from mellea.stdlib.base import ChatContext
 
 react_system_template: Template = Template(
     """Answer the user's question as best you can.
@@ -81,7 +82,7 @@ def call_tool(self, tool: ReactTool, kwargs_json: str):
     def tool_name_schema(self):
         names = self.tool_names()
         fields = dict()
-        fields["tool"] = Literal[Unpack[names]]
+        fields["tool"] = Literal[*names]
        return pydantic.create_model("ToolSelectionSchema", **fields)
 
     def get_tool_from_schema(self, content: str):
@@ -101,7 +102,7 @@ def react(
     react_toolbox: ReactToolbox,
 ):
     assert m.ctx.is_chat_context, "ReACT requires a chat context."
-    test_ctx_lin = m.ctx.render_for_generation()
+    test_ctx_lin = m.ctx.view_for_generation()
     assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
         "ReACT expects a fresh context."
     )
@@ -112,8 +113,9 @@ def react(
     )
 
     # Add the system prompt and the goal to the chat history.
-    m.ctx.insert(mellea.stdlib.chat.Message(role="system", content=_sys_prompt))
-    m.ctx.insert(mellea.stdlib.chat.Message(role="user", content=f"{goal}"))
+    m.ctx = m.ctx.add(
+        mellea.stdlib.chat.Message(role="system", content=_sys_prompt)
+    ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}"))
 
     # The main ReACT loop as a dynamic program:
     # ( ?(not done) ;
@@ -159,7 +161,7 @@ def react(
 
         print("### Observation")
         tool_output = react_toolbox.call_tool(selected_tool, act_args_val)
-        m.ctx.insert(mellea.stdlib.chat.Message(role="tool", content=tool_output))
+        m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output))
         print(tool_output)
 
         print("### Done Check")
@@ -187,7 +189,7 @@ def react(
 
 
 if __name__ == "__main__":
-    m = mellea.start_session(ctx=mellea.stdlib.base.LinearContext())
+    m = mellea.start_session(ctx=ChatContext())
 
     def zip_lookup_tool_fn(city: str):
         """Returns the ZIP code for the `city`."""

docs/examples/generative_slots/generate_with_context.py

Lines changed: 5 additions & 5 deletions
@@ -1,6 +1,6 @@
-from mellea import LinearContext, generative, start_session
+from mellea import generative, start_session
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import CBlock
+from mellea.stdlib.base import CBlock, ChatContext
 
 # Generative slots can be used with sessions that have context.
 # By utilizing context, you can change the results of several
@@ -34,7 +34,7 @@ def give_feedback(essay: str) -> list[str]:
 
 if __name__ == "__main__":
     m = start_session(
-        ctx=LinearContext(), model_options={ModelOption.MAX_NEW_TOKENS: 100}
+        ctx=ChatContext(), model_options={ModelOption.MAX_NEW_TOKENS: 100}
     )
 
     text = """
@@ -55,7 +55,7 @@ def give_feedback(essay: str) -> list[str]:
 
     # If you have a set of generative functions, you can tweak them all by
     # adding context to the session they are running in.
-    m.ctx.insert(
+    m.ctx = m.ctx.add(
         CBlock(
             "You are an elementary school teacher. "
             "Any grades and feedback that you give should keep that in mind. Remember to be "
@@ -74,7 +74,7 @@ def give_feedback(essay: str) -> list[str]:
 
     # And, let's reset the context and try a different grading style.
     m.reset()
-    m.ctx.insert(
+    m.ctx = m.ctx.add(
         CBlock(
             "You are a grammarian that is focused solely on spelling and syntax, "
             "not on the content of essays. When giving grades and feedback, focus "
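
The last two hunks above show the seed-then-rebind pattern that replaces m.ctx.insert(...): reset() returns the session to a fresh context, and each add() call hands back a new context that must be assigned to m.ctx. A condensed, self-contained sketch (persona strings abbreviated):

from mellea import start_session
from mellea.stdlib.base import CBlock, ChatContext

m = start_session(ctx=ChatContext())

# Seed the context; the rebind is required because add() returns a
# new context rather than mutating m.ctx in place.
m.ctx = m.ctx.add(CBlock("You are an elementary school teacher."))

# reset() clears the session context, after which a different persona
# can be seeded the same way.
m.reset()
m.ctx = m.ctx.add(CBlock("You are a grammarian focused on spelling and syntax."))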

docs/examples/helper/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .helpers import Any, fill, w
+from .helpers import req_print, w

docs/examples/helper/helpers.py

Lines changed: 7 additions & 0 deletions
@@ -1,7 +1,14 @@
 from textwrap import fill
 from typing import Any
 
+from mellea.stdlib.requirement import Requirement, ValidationResult
+
 
 # Just for printing stuff nicely...
 def w(x: Any) -> str:
     return fill(str(x), width=120, replace_whitespace=False)
+
+
+def req_print(rv_list: list[tuple[Requirement, ValidationResult]]) -> str:
+    parts = [f"{bool(rv[1])}\t: {rv[0].description}" for rv in rv_list]
+    return "\n".join(parts)

docs/examples/image_text_models/vision_ollama_chat.py

Lines changed: 4 additions & 4 deletions
@@ -2,11 +2,11 @@
 
 from PIL import Image
 
-from mellea import LinearContext, start_session
-from mellea.stdlib.base import ImageBlock
+from mellea import start_session
+from mellea.stdlib.base import ChatContext, ImageBlock
 
-m = start_session(model_id="granite3.2-vision", ctx=LinearContext())
-# m = start_session(model_id="llava", ctx=LinearContext())
+m = start_session(model_id="granite3.2-vision", ctx=ChatContext())
+# m = start_session(model_id="llava", ctx=ChatContext())
 
 # load image
 test_img = Image.open("pointing_up.jpg")

docs/examples/instruct_validate_repair/101_email.py

Lines changed: 2 additions & 13 deletions
@@ -1,30 +1,19 @@
 # This is the 101 example for using `session` and `instruct`.
 # helper function to wrap text
 from docs.examples.helper import w
-from mellea import instruct, start_session
+from mellea import start_session
 from mellea.backends.types import ModelOption
 
-# create a session using Granite 3.3 8B on Ollama and a simple context [see below]
-with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
-    # write an email
-    email_v1 = instruct("Write an email to invite all interns to the office party.")
-
 with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
     # write an email
     email_v1 = m.instruct("Write an email to invite all interns to the office party.")
+    print(m.last_prompt())
 
     # print result
     print(f"***** email ****\n{w(email_v1)}\n*******")
 
 # ************** END *************
 
-
-# # optionally: print the debug log for the last instruction on the context
-# from mellea.stdlib.base import GenerateLog
-# _, log = m.ctx.last_output_and_logs()
-# if isinstance(log, GenerateLog):  # should be
-#     print(f"Prompt:\n{w(log.prompt)}")  # print prompt
-
 # # start_session() is equivalent to:
 # from mellea.backends import model_ids
 # from mellea.backends.ollama import OllamaModelBackend
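
The deleted commented-out block that dug the prompt out of GenerateLog is superseded by the m.last_prompt() call added inside the with block. A minimal sketch of that inspection pattern under the same session setup (the instruction text is illustrative):

from mellea import start_session
from mellea.backends.types import ModelOption

with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
    m.instruct("Write a one-line greeting.")
    # last_prompt() exposes the prompt sent for the most recent
    # generation, replacing the old GenerateLog-based lookup.
    print(m.last_prompt())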

docs/examples/instruct_validate_repair/101_email_with_requirements.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 # create a session using Granite 3.3 8B on Ollama and a simple context [see below]
 m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200})
 
-# write an email
+# write an email with automatic requirement checking.
 email_v1 = m.instruct(
     "Write an email to invite all interns to the office party.",
     requirements=["be formal", "Use 'Dear interns' as greeting."],

docs/examples/instruct_validate_repair/101_email_with_validate.py

Lines changed: 14 additions & 4 deletions
@@ -1,19 +1,29 @@
-from docs.examples.helper import w
+from docs.examples.helper import req_print, w
 from mellea import start_session
 from mellea.backends.types import ModelOption
 from mellea.stdlib.sampling import RejectionSamplingStrategy
 
 # create a session using Granite 3.3 8B on Ollama and a simple context [see below]
 m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200})
 
-email_v1 = m.instruct(
+email_v2_samples = m.instruct(
     "Write an email to invite all interns to the office party.",
     requirements=["be formal", "Use 'Dear interns' as greeting."],
     strategy=RejectionSamplingStrategy(loop_budget=3),
+    return_sampling_results=True,
 )
 
-# print result
-print(f"***** email ****\n{w(email_v1)}\n*******")
+if email_v2_samples.success:
+    print(f"Success: \n{w(email_v2_samples.result)}")
+    print(
+        f"===> Requirement for this sample: \n{req_print(email_v2_samples.sample_validations[-1])}"
+    )
+else:
+    print(f"Failure: \n{w(email_v2_samples.result)}")
+    selected_index = email_v2_samples.sample_generations.index(email_v2_samples.result)
+    print(
+        f"===> Requirement for this sample: \n{req_print(email_v2_samples.sample_validations[selected_index])}"
+    )
 
 # # [optional] get logs for all loops:
 # from mellea.stdlib.base import GenerateLog
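
With return_sampling_results=True, instruct() returns a sampling-result object instead of the bare output; the example above relies on its .success, .result, .sample_generations, and .sample_validations fields, where sample_validations holds one list of (Requirement, ValidationResult) pairs per generated sample. Without the flag, the same call returns only the selected output, and per the commit message rejection sampling is now the default strategy; a sketch under that assumption:

from docs.examples.helper import w
from mellea import start_session
from mellea.backends.types import ModelOption

m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200})

# Without return_sampling_results=True, instruct() returns just the
# selected output; requirements are still checked while sampling.
email_v1 = m.instruct(
    "Write an email to invite all interns to the office party.",
    requirements=["be formal", "Use 'Dear interns' as greeting."],
)
print(f"***** email ****\n{w(email_v1)}\n*******")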
