Skip to content

Commit 6bc96be

Browse files
committed
use m.act as underlying func for generating
1 parent f647d74 commit 6bc96be

File tree

3 files changed

+189
-191
lines changed

3 files changed

+189
-191
lines changed

mellea/stdlib/genslot.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydantic import BaseModel, Field, create_model
1010

1111
from mellea.stdlib.base import Component, TemplateRepresentation
12-
from mellea.stdlib.session import get_session
12+
from mellea.stdlib.session import MelleaSession, get_session
1313

1414
P = ParamSpec("P")
1515
R = TypeVar("R")
@@ -154,7 +154,7 @@ def __init__(self, func: Callable[P, R]):
154154

155155
def __call__(
156156
self,
157-
m=None,
157+
m: MelleaSession | None = None,
158158
model_options: dict | None = None,
159159
*args: P.args,
160160
**kwargs: P.kwargs,
@@ -180,13 +180,11 @@ def __call__(
180180

181181
response_model = create_response_format(self._function._func)
182182

183-
response = m.genslot(
184-
slot_copy, model_options=model_options, format=response_model
185-
)
183+
response = m.act(slot_copy, format=response_model, model_options=model_options)
186184

187185
function_response: FunctionResponse[R] = response_model.model_validate_json(
188-
response.value
189-
) # type: ignore
186+
response.value # type: ignore
187+
)
190188

191189
return function_response.result
192190

mellea/stdlib/sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def select_from_failure(
153153
sampled_actions: list[Component],
154154
sampled_results: list[ModelOutputThunk],
155155
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
156-
):
156+
) -> int:
157157
"""This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success.
158158
159159
Args:
@@ -356,17 +356,17 @@ def select_from_failure(
356356

357357
@staticmethod
358358
def repair(
359-
context: Context,
359+
ctx: Context,
360360
past_actions: list[Component],
361361
past_results: list[ModelOutputThunk],
362362
past_val: list[list[tuple[Requirement, ValidationResult]]],
363363
) -> Component:
364-
assert isinstance(context, LinearContext), (
364+
assert isinstance(ctx, LinearContext), (
365365
" Need linear context to run agentic sampling."
366366
)
367367

368368
# add failed execution to chat history
369-
context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
369+
ctx.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
370370

371371
last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
372372
last_failed_reqs_str = "* " + "\n* ".join(

0 commit comments

Comments
 (0)