Skip to content

Commit 2ab8b0d

Browse files
committed
feat: add warning for async with non-Simple contexts
1 parent cf96cdb commit 2ab8b0d

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

mellea/stdlib/funcs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def act(
103103
format=format,
104104
model_options=model_options,
105105
tool_calls=tool_calls,
106+
silence_context_type_warning=True, # We can safely silence this here since it's in a sync function.
106107
) # type: ignore[call-overload]
107108
# Mypy doesn't like the bool for return_sampling_results.
108109
)
@@ -425,6 +426,7 @@ async def aact(
425426
format: type[BaseModelSubclass] | None = None,
426427
model_options: dict | None = None,
427428
tool_calls: bool = False,
429+
silence_context_type_warning: bool = False,
428430
) -> tuple[ModelOutputThunk, Context]: ...
429431

430432

@@ -440,6 +442,7 @@ async def aact(
440442
format: type[BaseModelSubclass] | None = None,
441443
model_options: dict | None = None,
442444
tool_calls: bool = False,
445+
silence_context_type_warning: bool = False,
443446
) -> SamplingResult: ...
444447

445448

@@ -454,6 +457,7 @@ async def aact(
454457
format: type[BaseModelSubclass] | None = None,
455458
model_options: dict | None = None,
456459
tool_calls: bool = False,
460+
silence_context_type_warning: bool = False,
457461
) -> tuple[ModelOutputThunk, Context] | SamplingResult:
458462
"""Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context.
459463
@@ -467,10 +471,18 @@ async def aact(
467471
format: if set, the BaseModel to use for constrained decoding.
468472
model_options: additional model options, which will upsert into the model/backend's defaults.
469473
tool_calls: if true, tool calling is enabled.
474+
silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext
470475
471476
Returns:
472477
A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
473478
"""
479+
480+
if not silence_context_type_warning and not isinstance(context, SimpleContext):
481+
FancyLogger().get_logger().warning(
482+
"Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests."
483+
"\nSee the async section of the tutorial: https://github.com/generative-computing/mellea/blob/main/docs/tutorial.md#chapter-12-asynchronicity"
484+
)
485+
474486
sampling_result: SamplingResult | None = None
475487
generate_logs: list[GenerateLog] = []
476488

0 commit comments

Comments
 (0)