
Commit 4a28977

Merge remote-tracking branch 'origin/main' into tests-isolation
Signed-off-by: elronbandel <[email protected]>
2 parents 6a406e6 + 77d2f08

File tree: 18 files changed, +275 -115 lines

docs/examples/agents/react.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def react(
     react_toolbox: ReactToolbox,
 ):
     assert m.ctx.is_chat_context, "ReACT requires a chat context."
-    test_ctx_lin = m.ctx.linearize()
+    test_ctx_lin = m.ctx.render_for_generation()
     assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
         "ReACT expects a fresh context."
     )

docs/examples/agents/react_instruct.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def react(
     react_toolbox: ReactToolbox,
 ):
     assert m.ctx.is_chat_context, "ReACT requires a chat context."
-    test_ctx_lin = m.ctx.linearize()
+    test_ctx_lin = m.ctx.render_for_generation()
     assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
         "ReACT expects a fresh context."
     )

docs/tutorial.md

Lines changed: 10 additions & 8 deletions
@@ -968,15 +968,15 @@ Let's look at how this agent is implemented in Mellea:
 ```python
 # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/agents/react.py#L99
 def react(
-        m: mellea.MelleaSession,
-        goal: str,
-        react_toolbox: ReactToolbox,
-        budget: int=5,
+    m: mellea.MelleaSession,
+    goal: str,
+    react_toolbox: ReactToolbox,
+    budget: int = 5,
 ):
     assert m.ctx.is_chat_context, "ReACT requires a chat context."
-    test_ctx_lin = m.ctx.linearize()
+    test_ctx_lin = m.ctx.render_for_generation()
     assert (
-            test_ctx_lin is not None and len(test_ctx_lin) == 0
+        test_ctx_lin is not None and len(test_ctx_lin) == 0
     ), "ReACT expects a fresh context."
 
     # Construct the system prompt for ReACT.

@@ -1006,15 +1006,17 @@ def react(
         # model_options={mellea.backends.types.ModelOption.TOOLS: react_toolbox.tools_dict()},
         format=react_toolbox.tool_name_schema(),
     )
-    selected_tool: ReactTool = react_toolbox.get_tool_from_schema(act.content)
+    selected_tool: ReactTool = react_toolbox.get_tool_from_schema(
+        act.content)
     print(selected_tool.get_name())
 
     print(f"### Arguments for action")
     act_args = m.chat(
         "Choose arguments for the tool. Respond using JSON and include only the tool arguments in your response.",
         format=selected_tool.args_schema(),
     )
-    print(f"```json\n{json.dumps(json.loads(act_args.content), indent=2)}\n```")
+    print(
+        f"```json\n{json.dumps(json.loads(act_args.content), indent=2)}\n```")
 
     # TODO: handle exceptions.
     print("### Observation")

mellea/backends/formatter.py

Lines changed: 9 additions & 9 deletions
@@ -5,6 +5,7 @@
 import re
 import sys
 from collections.abc import Iterable, Mapping
+from dataclasses import fields
 from typing import Any
 
 import jinja2

@@ -166,7 +167,7 @@ def print_context(self, ctx: Context) -> str:
         )
         match ctx:
             case LinearContext():
-                linearized_ctx = ctx.linearize()
+                linearized_ctx = ctx.render_for_generation()
                 assert linearized_ctx is not None
                 return "".join([self.print(x) for x in linearized_ctx])
             case SimpleContext():

@@ -396,14 +397,13 @@ def _get_model_id(self) -> str:
                 "model_id was neither a `str` nor `ModelIdentifier`"
             )
 
-        # Go through the ModelIdentifier's fields, find one that isn't `"None"` or `""`.
-        ids = [model_id.hf_model_name, model_id.ollama_name]
-        model_id = ""
-        for val in ids:
-            if val != "None" and val != "":
-                model_id = val  # type: ignore
-                break
-        return model_id
+        # Go through the ModelIdentifier's fields, find one that can be matched against.
+        for field in fields(model_id):
+            val = getattr(model_id, field.name)
+            if val is not None and val != "":
+                return val
+
+        return ""  # Cannot match against any model identifiers. Will ultimately use default.
 
 
 def _simplify_model_string(input: str) -> str:
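
The `_get_model_id` rewrite swaps a hand-maintained list of identifier attributes for iteration over the dataclass's declared fields, so a newly added backend name on `ModelIdentifier` is picked up without touching this method. A minimal self-contained sketch of the pattern (the two fields shown are illustrative, not the library's full definition):

```python
from dataclasses import dataclass, fields


@dataclass
class ModelIdentifier:
    # Illustrative fields only; the real class may declare more backends.
    hf_model_name: str | None = None
    ollama_name: str | None = None


def first_usable_id(model_id: ModelIdentifier) -> str:
    # Scan declared fields in order; return the first non-empty value.
    for field in fields(model_id):
        val = getattr(model_id, field.name)
        if val is not None and val != "":
            return val
    return ""  # No usable identifier; the caller falls back to a default.


print(first_usable_id(ModelIdentifier(ollama_name="granite3.3:8b")))  # granite3.3:8b
```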

mellea/backends/huggingface.py

Lines changed: 11 additions & 8 deletions
@@ -12,12 +12,11 @@
 import json
 import os
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import outlines
 import outlines_core
 import torch
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,

@@ -26,7 +25,6 @@
     PreTrainedTokenizer,
     set_seed,
 )
-from transformers.generation import GenerateDecoderOnlyOutput
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin

@@ -52,6 +50,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 
+if TYPE_CHECKING:
+    from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
 assert outlines, "outlines needs to be present to make outlines_core work"
 
 """A configuration type for the unhappy path: Tokenizer * Model * torch device string

@@ -160,17 +161,17 @@ def __init__(
         self._cache = cache if cache is not None else SimpleLRUCache(3)
 
         # Used when running aLoRAs with this backend.
-        self._alora_model: aLoRAPeftModelForCausalLM | None = None
+        self._alora_model: "aLoRAPeftModelForCausalLM | None" = None  # noqa: UP037
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, HFAlora] = {}
 
     @property
-    def alora_model(self) -> aLoRAPeftModelForCausalLM | None:
+    def alora_model(self) -> "aLoRAPeftModelForCausalLM | None":  # noqa: UP037
         """The ALora model."""
         return self._alora_model
 
     @alora_model.setter
-    def alora_model(self, model: aLoRAPeftModelForCausalLM | None):
+    def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"):  # noqa: UP037
         """Sets the ALora model. This should only happen once in a backend's lifetime."""
         assert self._alora_model is None
         self._alora_model = model

@@ -239,7 +240,7 @@ def _generate_from_context_alora(
             "This code block should not execute unless there is a 'constraint' alora loaded."
         )
         # Construct the linearized context. This is very similar to normal generation.
-        linearized_ctx = ctx.linearize()
+        linearized_ctx = ctx.render_for_generation()
         assert linearized_ctx is not None and len(linearized_ctx) > 1
         msgs = self.formatter.to_chat_messages(linearized_ctx)
         user_message, assistant_message = msgs[-2].content, msgs[-1].content

@@ -286,7 +287,7 @@ def _generate_from_context_standard(
         # Otherwise, we will linearize the context and treat it as a raw input.
         decoded_result: str | None = None
         if ctx.is_chat_context:
-            linearized_ctx = ctx.linearize()
+            linearized_ctx = ctx.render_for_generation()
             assert linearized_ctx is not None, (
                 "If ctx.is_chat_context, then the context should be linearizable."
             )

@@ -624,6 +625,8 @@ def add_alora(self, alora: HFAlora):
         Args:
             alora (str): identifier for the ALora adapter
         """
+        from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
         assert issubclass(alora.__class__, HFAlora), (
             f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}"
         )
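
Both this file and `mellea/backends/openai.py` below defer a heavy import the same way: the module-level import moves under `typing.TYPE_CHECKING`, where only static type checkers evaluate it, annotations become quoted strings, and the runtime import happens lazily inside the one method that needs it. A minimal sketch of the pattern, using the standard library's `sqlite3` as a stand-in for the optional dependency:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; no import cost at runtime.
    from sqlite3 import Connection


class Store:
    def __init__(self) -> None:
        # The quoted annotation names the type without importing it.
        self._conn: "Connection | None" = None

    def open(self) -> None:
        # Deferred import: the dependency loads only when first used.
        import sqlite3

        self._conn = sqlite3.connect(":memory:")
```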

mellea/backends/ollama.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def generate_from_chat_context(
         """
         model_opts = self._simplify_and_merge(model_options)
 
-        linearized_context = ctx.linearize()
+        linearized_context = ctx.render_for_generation()
         assert linearized_context is not None, (
             "Cannot generate from a non-linear context in a FormatterBackend."
         )

mellea/backends/openai.py

Lines changed: 9 additions & 7 deletions
@@ -6,19 +6,16 @@
 import json
 from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
 import openai
 import requests
 from huggingface_hub import snapshot_download
 from openai.types.chat import ChatCompletion
 from openai.types.completion import Completion
-from transformers import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer
 
 import mellea.backends.model_ids as model_ids
-from cli.serve.models import ChatCompletionMessage
 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter

@@ -38,6 +35,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 
+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"

@@ -328,7 +328,7 @@ def _generate_from_chat_context_alora(
         )
 
         # Construct the linearized context. This is very similar to normal generation.
-        linearized_ctx = ctx.linearize()
+        linearized_ctx = ctx.render_for_generation()
         assert linearized_ctx is not None and len(linearized_ctx) > 1
         msgs = self.formatter.to_chat_messages(linearized_ctx)
         user_message, assistant_message = msgs[-2].content, msgs[-1].content

@@ -363,7 +363,7 @@ def _generate_from_chat_context_standard(
         model_opts = self._simplify_and_merge(
             model_options, is_chat_context=ctx.is_chat_context
         )
-        linearized_context = ctx.linearize()
+        linearized_context = ctx.render_for_generation()
         assert linearized_context is not None, (
             "Cannot generate from a non-linear context in a FormatterBackend."
         )

@@ -639,10 +639,12 @@ def get_aloras(self) -> list[Alora]:
 
     def apply_chat_template(self, chat: list[dict[str, str]]):
         """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
+        from transformers import AutoTokenizer
+
         if not hasattr(self, "_tokenizer"):
             match _server_type(self._base_url):
                 case _ServerType.LOCALHOST:
-                    self._tokenizer: PreTrainedTokenizer = (
+                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
                         AutoTokenizer.from_pretrained(self._hf_model_id)
                     )
                 case _ServerType.OPENAI:

mellea/backends/watsonx.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def generate_from_chat_context(
             model_options, is_chat_context=ctx.is_chat_context
         )
 
-        linearized_context = ctx.linearize()
+        linearized_context = ctx.render_for_generation()
         assert linearized_context is not None, (
             "Cannot generate from a non-linear context in a FormatterBackend."
         )

mellea/stdlib/base.py

Lines changed: 14 additions & 5 deletions
@@ -157,8 +157,13 @@ def _hash_for_kv_cache(self):
         ...
 
     @abc.abstractmethod
-    def linearize(self) -> list[Component | CBlock] | None:
-        """Provides a linear list of context components. This is not always possible, or None if that is not possible to construct."""
+    def render_for_generation(self) -> list[Component | CBlock] | None:
+        """Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
+        ...
+
+    @abc.abstractmethod
+    def full_event_log(self) -> list[Component | CBlock]:
+        """Provides a list of all events stored in the context."""
         ...
 
     @abc.abstractmethod

@@ -262,6 +267,10 @@ def last_output_and_logs(
         )
         return last, log[0]
 
+    def full_event_log(self) -> list[Component | CBlock]:
+        """Returns the underlying _ctx."""
+        return self._ctx
+
     def last_turn(self):
         """The last input/output turn of the context."""
         if len(self._ctx) == 0:

@@ -335,8 +344,8 @@ def insert_turn(
         if turn.output:
             self.insert(turn.output, generate_logs=generate_logs)
 
-    def linearize(self) -> list[Component | CBlock] | None:
-        """Returns the underlying _ctx list."""
+    def render_for_generation(self) -> list[Component | CBlock] | None:
+        """Returns the underlying _ctx list for generation."""
         return self._ctx
 
     def is_chat_history(self):

@@ -372,7 +381,7 @@ def __init__(self):
         super().__init__()
         self.is_chat_context = True
 
-    def linearize(self) -> list[Component | CBlock] | None:
+    def render_for_generation(self) -> list[Component | CBlock] | None:
         """Uses _ctx ordering."""
         return []
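
This is the substantive API change behind the rename: `linearize()` splits into `render_for_generation()`, the view a backend should feed to the model (possibly pruned, or `None` for non-linear contexts), and `full_event_log()`, everything the context has recorded. For `LinearContext` the two coincide, but a custom context could diverge. A hypothetical sketch, with plain strings standing in for `Component | CBlock`:

```python
class WindowedContext:
    """Hypothetical context: records every event, renders only a window."""

    def __init__(self, window: int = 4):
        self._ctx: list[str] = []
        self.window = window

    def insert(self, event: str) -> None:
        self._ctx.append(event)

    def render_for_generation(self) -> list[str] | None:
        # What the backend sees: only the trailing window of events.
        return self._ctx[-self.window:]

    def full_event_log(self) -> list[str]:
        # Everything ever recorded, e.g. for casting to a chat history.
        return self._ctx


ctx = WindowedContext(window=2)
for turn in ["sys", "user1", "asst1", "user2"]:
    ctx.insert(turn)
print(ctx.render_for_generation())  # ['asst1', 'user2']
print(ctx.full_event_log())         # ['sys', 'user1', 'asst1', 'user2']
```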

mellea/stdlib/chat.py

Lines changed: 3 additions & 3 deletions
@@ -111,10 +111,10 @@ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None:
         case _:
             return None
 
-    linearized_ctx = ctx.linearize()
-    if linearized_ctx is None:
+    all_ctx_events = ctx.full_event_log()
+    if all_ctx_events is None:
         raise Exception("Trying to cast a non-linear history into a chat history.")
     else:
-        history = [_to_msg(c) for c in linearized_ctx]
+        history = [_to_msg(c) for c in all_ctx_events]
         assert None not in history, "Could not render this context as a chat history."
         return history  # type: ignore
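
Casting to a chat history accordingly switches from the generation view to the full event log, so the cast sees every message even when `render_for_generation()` prunes events or returns `None`. Reusing the hypothetical `WindowedContext` from the sketch above:

```python
def as_chat_history(ctx: WindowedContext) -> list[str]:
    # Build the history from everything recorded, not from the
    # pruned view that generation sees.
    return list(ctx.full_event_log())


ctx = WindowedContext(window=2)
for turn in ["sys", "user1", "asst1", "user2"]:
    ctx.insert(turn)
print(as_chat_history(ctx))         # all four events
print(ctx.render_for_generation())  # only ['asst1', 'user2']
```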
