Commit df7d2da

adding first test case

Authored and committed by Maxwell Crouse [email protected]
1 parent 788f339

3 files changed: +41 additions, -2 deletions

mellea/backends/openai.py

Lines changed: 2 additions & 1 deletion
@@ -311,7 +311,8 @@ def generate_from_context(
             tool_calls=tool_calls,
             labels=labels,
         )
-        #
+
+        # only add action to context if provided
         if action is not None:
             ctx = ctx.add(action, labels=labels)
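
Note: the change above is only a clarifying comment plus spacing; the guard it documents appends the action to the context, tagged with the caller's labels, only when an action was actually passed in. A minimal standalone sketch of that guard (only ctx, action, labels, and ctx.add come from the diff; the helper name and its placement are hypothetical):

# Hypothetical extraction of the guard documented above; in the real backend
# it sits inline near the end of generate_from_context.
def _maybe_add_action(ctx, action, labels):
    # only add action to context if provided
    if action is not None:
        ctx = ctx.add(action, labels=labels)
    return ctx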

mellea/stdlib/functional.py

Lines changed: 22 additions & 1 deletion
@@ -46,6 +46,7 @@ def act(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context]: ...


@@ -61,6 +62,7 @@ def act(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> SamplingResult: ...


@@ -75,6 +77,7 @@ def act(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context] | SamplingResult:
     """Runs a generic action, and adds both the action and the result to the context.

@@ -88,6 +91,7 @@ def act(
         format: if set, the BaseModel to use for constrained decoding.
         model_options: additional model options, which will upsert into the model/backend's defaults.
         tool_calls: if true, tool calling is enabled.
+        labels: if provided, restrict generation to context nodes with matching types.

     Returns:
         A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -104,6 +108,7 @@ def act(
             model_options=model_options,
             tool_calls=tool_calls,
             silence_context_type_warning=True,  # We can safely silence this here since it's in a sync function.
+            labels=labels,
         )  # type: ignore[call-overload]
         # Mypy doesn't like the bool for return_sampling_results.
     )
@@ -129,6 +134,7 @@ def instruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context]: ...


@@ -150,6 +156,7 @@ def instruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> SamplingResult: ...


@@ -170,6 +177,7 @@ def instruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context] | SamplingResult:
     """Generates from an instruction.

@@ -189,6 +197,7 @@ def instruct(
         model_options: Additional model options, which will upsert into the model/backend's defaults.
         tool_calls: If true, tool calling is enabled.
         images: A list of images to be used in the instruction or None if none.
+        labels: if provided, restrict generation to context nodes with matching types.

     Returns:
         A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -221,6 +230,7 @@ def instruct(
         format=format,
         model_options=model_options,
         tool_calls=tool_calls,
+        labels=labels,
     )  # type: ignore[call-overload]


@@ -235,6 +245,7 @@ def chat(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[Message, Context]:
     """Sends a simple chat message and returns the response. Adds both messages to the Context."""
     if user_variables is not None:
@@ -254,6 +265,7 @@ def chat(
         format=format,
         model_options=model_options,
         tool_calls=tool_calls,
+        labels=labels,
     )
     parsed_assistant_message = result.parsed_repr
     assert isinstance(parsed_assistant_message, Message)
@@ -429,6 +441,7 @@ async def aact(
     model_options: dict | None = None,
     tool_calls: bool = False,
     silence_context_type_warning: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context]: ...


@@ -445,6 +458,7 @@ async def aact(
     model_options: dict | None = None,
     tool_calls: bool = False,
     silence_context_type_warning: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> SamplingResult: ...


@@ -475,7 +489,7 @@ async def aact(
         model_options: additional model options, which will upsert into the model/backend's defaults.
         tool_calls: if true, tool calling is enabled.
         silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext
-        labels: if provided, restrict generation to context nodes with matching labels
+        labels: if provided, restrict generation to context nodes with matching types.

     Returns:
         A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -571,6 +585,7 @@ async def ainstruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context]: ...


@@ -592,6 +607,7 @@ async def ainstruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> SamplingResult: ...


@@ -612,6 +628,7 @@ async def ainstruct(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[ModelOutputThunk, Context] | SamplingResult:
     """Generates from an instruction.

@@ -631,6 +648,7 @@ async def ainstruct(
         model_options: Additional model options, which will upsert into the model/backend's defaults.
         tool_calls: If true, tool calling is enabled.
         images: A list of images to be used in the instruction or None if none.
+        labels: if provided, restrict generation to context nodes with matching types.

     Returns:
         A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -663,6 +681,7 @@ async def ainstruct(
         format=format,
         model_options=model_options,
         tool_calls=tool_calls,
+        labels=labels,
     )  # type: ignore[call-overload]


@@ -677,6 +696,7 @@ async def achat(
     format: type[BaseModelSubclass] | None = None,
     model_options: dict | None = None,
     tool_calls: bool = False,
+    labels: Sequence[str] | None = None,
 ) -> tuple[Message, Context]:
     """Sends a simple chat message and returns the response. Adds both messages to the Context."""
     if user_variables is not None:
@@ -696,6 +716,7 @@ async def achat(
         format=format,
         model_options=model_options,
         tool_calls=tool_calls,
+        labels=labels,
     )
     parsed_assistant_message = result.parsed_repr
     assert isinstance(parsed_assistant_message, Message)
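
Note: taken together, these hunks thread an optional labels argument through the synchronous and asynchronous entry points (act/aact, instruct/ainstruct, chat/achat) and forward it to the underlying backend call. A rough usage sketch follows; the leading arguments (instruction text, context, backend) are not visible in this diff, so their names and positions are assumptions. Only the labels keyword, the module path, and the (ModelOutputThunk, Context) return shape come from the change.

from mellea.stdlib import functional as mfuncs  # module path taken from the file path in this diff

# Hypothetical call site (names before `labels` are assumptions; `ctx` and
# `backend` stand for whatever Context/backend the caller already has).
out, ctx = mfuncs.instruct(
    "Summarize the meeting notes.",  # instruction text (assumed positional)
    context=ctx,
    backend=backend,
    labels=["notes"],  # new: restrict generation to matching context nodes
)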

test/stdlib_basics/test_base_context.py

Lines changed: 17 additions & 0 deletions
@@ -67,5 +67,22 @@ def test_actions_for_available_tools():
         assert actions[i] == for_generation[i]


+def test_render_view_for_chat_context_with_labels():
+    ctx = ChatContext(window_size=3)
+    for i in range(5):
+        ctx = ctx.add(CBlock(f"a {i}"), labels=[str(i // 2)])
+
+    # no labels
+    assert len(ctx.as_list()) == 5, "Context size must be 5"
+    assert len(ctx.view_for_generation()) == 3, "Render size must be 3"
+
+    # with explicit labels
+    for labels, al_sz, vg_sz in [(None, 5, 3), ([str(0)], 2, 2), ([str(2)], 1, 1)]:
+        assert len(ctx.as_list(labels=labels)) == al_sz, f"Context size must be {al_sz}"
+        assert len(ctx.view_for_generation(labels=labels)) == vg_sz, (
+            f"Render size must be {vg_sz}"
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
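
Note: the new test pins down the intended semantics. Five CBlocks are added with labels "0", "0", "1", "1", "2"; without a filter the context lists 5 entries and the window-limited render shows 3, while filtering by "0" or "2" yields 2 and 1 entries. view_for_generation(labels=["0"]) still returns 2 even though those two blocks are older than the three-item window, which suggests the label filter is applied before the window cut. Below is a self-contained model of the label filtering the assertions imply; it is not the mellea implementation, and the intersection rule is an assumption.

from typing import Sequence

# Toy model of label filtering (assumption: an item matches when its label set
# intersects the requested labels; None means "no filtering").
def filter_by_labels(
    items: list[tuple[str, list[str]]], labels: Sequence[str] | None
) -> list[str]:
    if labels is None:
        return [value for value, _ in items]
    wanted = set(labels)
    return [value for value, item_labels in items if wanted & set(item_labels)]

items = [(f"a {i}", [str(i // 2)]) for i in range(5)]  # labels: 0, 0, 1, 1, 2
assert len(filter_by_labels(items, None)) == 5
assert len(filter_by_labels(items, ["0"])) == 2  # mirrors as_list(labels=["0"])
assert len(filter_by_labels(items, ["2"])) == 1  # mirrors as_list(labels=["2"])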
