Skip to content

Commit c3854f4

Browse files
committed
fix: mypy issues
1 parent 849a7ed commit c3854f4

File tree

1 file changed

+54
-32
lines changed

1 file changed

+54
-32
lines changed

mellea/stdlib/mellea_functions.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
21
# TODO: JAL rename this file...
32
from __future__ import annotations
43

54
import asyncio
65
from collections.abc import Coroutine
76
from concurrent.futures import ThreadPoolExecutor
8-
from typing import Any, Literal, Tuple, TypeVar, overload
7+
from typing import Any, Literal, TypeVar, overload
98

109
from PIL import Image as PILImage
1110

@@ -46,7 +45,8 @@ def act(
4645
format: type[BaseModelSubclass] | None = None,
4746
model_options: dict | None = None,
4847
tool_calls: bool = False,
49-
) -> Tuple[ModelOutputThunk, Context]: ...
48+
) -> tuple[ModelOutputThunk, Context]: ...
49+
5050

5151
@overload
5252
def act(
@@ -61,6 +61,7 @@ def act(
6161
tool_calls: bool = False,
6262
) -> SamplingResult: ...
6363

64+
6465
def act(
6566
action: Component,
6667
context: Context,
@@ -72,13 +73,13 @@ def act(
7273
format: type[BaseModelSubclass] | None = None,
7374
model_options: dict | None = None,
7475
tool_calls: bool = False,
75-
) -> Tuple[ModelOutputThunk, Context] | SamplingResult:
76+
) -> tuple[ModelOutputThunk, Context] | SamplingResult:
7677
"""Runs a generic action, and adds both the action and the result to the context.
7778
7879
Args:
7980
action: the Component from which to generate.
8081
context: the context being used as a history from which to generate the response.
81-
backend: the backend used to generate the response.
82+
backend: the backend used to generate the response.
8283
requirements: used as additional requirements when a sampling strategy is provided.
8384
strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used.
8485
return_sampling_results: attach the (successful and failed) sampling attempts to the results.
@@ -106,6 +107,7 @@ def act(
106107

107108
return out
108109

110+
109111
async def _act(
110112
action: Component,
111113
context: Context,
@@ -117,13 +119,13 @@ async def _act(
117119
format: type[BaseModelSubclass] | None = None,
118120
model_options: dict | None = None,
119121
tool_calls: bool = False,
120-
) -> Tuple[ModelOutputThunk, Context] | SamplingResult:
122+
) -> tuple[ModelOutputThunk, Context] | SamplingResult:
121123
"""Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context.
122124
123125
Args:
124126
action: the Component from which to generate.
125127
context: the context being used as a history from which to generate the response.
126-
backend: the backend used to generate the response.
128+
backend: the backend used to generate the response.
127129
requirements: used as additional requirements when a sampling strategy is provided
128130
strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used.
129131
return_sampling_results: attach the (successful and failed) sampling attempts to the results.
@@ -167,10 +169,10 @@ async def _act(
167169
requirements = []
168170

169171
sampling_result = await strategy.sample(
170-
action,
171-
context=context,
172-
backend=backend,
173-
requirements=requirements,
172+
action,
173+
context=context,
174+
backend=backend,
175+
requirements=requirements,
174176
validation_ctx=None,
175177
format=format,
176178
model_options=model_options,
@@ -179,9 +181,7 @@ async def _act(
179181

180182
assert sampling_result.sample_generations is not None
181183
for result in sampling_result.sample_generations:
182-
assert (
183-
result._generate_log is not None
184-
) # Cannot be None after generation.
184+
assert result._generate_log is not None # Cannot be None after generation.
185185
generate_logs.append(result._generate_log)
186186

187187
# TODO: JAL. Extract the context from the sampling result.
@@ -200,6 +200,7 @@ async def _act(
200200
else:
201201
return result, new_ctx
202202

203+
203204
@overload
204205
def instruct(
205206
description: str,
@@ -218,7 +219,8 @@ def instruct(
218219
format: type[BaseModelSubclass] | None = None,
219220
model_options: dict | None = None,
220221
tool_calls: bool = False,
221-
) -> Tuple[ModelOutputThunk, Context]: ...
222+
) -> tuple[ModelOutputThunk, Context]: ...
223+
222224

223225
@overload
224226
def instruct(
@@ -240,6 +242,7 @@ def instruct(
240242
tool_calls: bool = False,
241243
) -> SamplingResult: ...
242244

245+
243246
def instruct(
244247
description: str,
245248
context: Context,
@@ -257,7 +260,7 @@ def instruct(
257260
format: type[BaseModelSubclass] | None = None,
258261
model_options: dict | None = None,
259262
tool_calls: bool = False,
260-
) -> Tuple[ModelOutputThunk, Context] | SamplingResult:
263+
) -> tuple[ModelOutputThunk, Context] | SamplingResult:
261264
"""Generates from an instruction.
262265
263266
Args:
@@ -308,6 +311,7 @@ def instruct(
308311
tool_calls=tool_calls,
309312
) # type: ignore[call-overload]
310313

314+
311315
def chat(
312316
content: str,
313317
context: Context,
@@ -319,7 +323,7 @@ def chat(
319323
format: type[BaseModelSubclass] | None = None,
320324
model_options: dict | None = None,
321325
tool_calls: bool = False,
322-
) -> Tuple[Message, Context]:
326+
) -> tuple[Message, Context]:
323327
"""Sends a simple chat message and returns the response. Adds both messages to the Context."""
324328
if user_variables is not None:
325329
content_resolved = Instruction.apply_user_dict_from_jinja(
@@ -343,6 +347,7 @@ def chat(
343347

344348
return parsed_assistant_message, new_ctx
345349

350+
346351
def validate(
347352
reqs: Requirement | list[Requirement],
348353
context: Context,
@@ -351,7 +356,8 @@ def validate(
351356
output: CBlock | None = None,
352357
format: type[BaseModelSubclass] | None = None,
353358
model_options: dict | None = None,
354-
generate_logs: list[GenerateLog] | None = None, # TODO: Can we get rid of gen logs here and in act?
359+
generate_logs: list[GenerateLog]
360+
| None = None, # TODO: Can we get rid of gen logs here and in act?
355361
input: CBlock | None = None,
356362
) -> list[ValidationResult]:
357363
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
@@ -367,12 +373,13 @@ def validate(
367373
model_options=model_options,
368374
generate_logs=generate_logs,
369375
input=input,
370-
),
376+
)
371377
)
372378

373379
# Wait for and return the result.
374380
return out
375381

382+
376383
async def _validate(
377384
reqs: Requirement | list[Requirement],
378385
context: Context,
@@ -383,7 +390,9 @@ async def _validate(
383390
model_options: dict | None = None,
384391
generate_logs: list[GenerateLog] | None = None,
385392
input: CBlock | None = None,
386-
) -> list[ValidationResult]: # TODO: JAL. We should think about returning the contexts as well.
393+
) -> list[
394+
ValidationResult
395+
]: # TODO: JAL. We should think about returning the contexts as well.
387396
"""Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
388397
# Turn a solitary requirement in to a list of requirements, and then reqify if needed.
389398
reqs = [reqs] if not isinstance(reqs, list) else reqs
@@ -403,10 +412,7 @@ async def _validate(
403412

404413
for requirement in reqs:
405414
val_result_co = requirement.validate(
406-
backend,
407-
validation_target_ctx,
408-
format=format,
409-
model_options=model_options,
415+
backend, validation_target_ctx, format=format, model_options=model_options
410416
)
411417
# TODO: JAL. do we ever need the context of a requirement / validation result.
412418
coroutines.append(val_result_co)
@@ -432,6 +438,7 @@ async def _validate(
432438

433439
return rvs
434440

441+
435442
def query(
436443
obj: Any,
437444
query: str,
@@ -441,7 +448,7 @@ def query(
441448
format: type[BaseModelSubclass] | None = None,
442449
model_options: dict | None = None,
443450
tool_calls: bool = False,
444-
) -> Tuple[ModelOutputThunk, Context]:
451+
) -> tuple[ModelOutputThunk, Context]:
445452
"""Query method for retrieving information from an object.
446453
447454
Args:
@@ -463,10 +470,16 @@ def query(
463470
q = obj.get_query_object(query)
464471

465472
answer = act(
466-
q, context=context, backend=backend, format=format, model_options=model_options, tool_calls=tool_calls
473+
q,
474+
context=context,
475+
backend=backend,
476+
format=format,
477+
model_options=model_options,
478+
tool_calls=tool_calls,
467479
)
468480
return answer
469481

482+
470483
def transform(
471484
obj: Any,
472485
transformation: str,
@@ -475,7 +488,7 @@ def transform(
475488
*,
476489
format: type[BaseModelSubclass] | None = None,
477490
model_options: dict | None = None,
478-
) -> Tuple[ModelOutputThunk | Any, Context]:
491+
) -> tuple[ModelOutputThunk | Any, Context]:
479492
"""Transform method for creating a new object with the transformation applied.
480493
481494
Args:
@@ -498,7 +511,12 @@ def transform(
498511
# Check that your model / backend supports tool calling.
499512
# This might throw an error when tools are provided but can't be handled by one or the other.
500513
transformed, new_ctx = act(
501-
t, context=context, backend=backend, format=format, model_options=model_options, tool_calls=True
514+
t,
515+
context=context,
516+
backend=backend,
517+
format=format,
518+
model_options=model_options,
519+
tool_calls=True,
502520
)
503521

504522
tools = _call_tools(transformed, backend)
@@ -539,6 +557,7 @@ def transform(
539557

540558
return transformed, new_ctx
541559

560+
542561
def _parse_and_clean_image_args(
543562
images_: list[ImageBlock] | list[PILImage.Image] | None = None,
544563
) -> list[ImageBlock] | None:
@@ -563,6 +582,7 @@ def _parse_and_clean_image_args(
563582
images = None
564583
return images
565584

585+
566586
def _call_tools(result: ModelOutputThunk, backend: Backend) -> list[ToolMessage]:
567587
"""Call all the tools requested in a result's tool calls object.
568588
@@ -596,13 +616,17 @@ def _call_tools(result: ModelOutputThunk, backend: Backend) -> list[ToolMessage]
596616
)
597617
return outputs
598618

619+
599620
R = TypeVar("R")
621+
622+
600623
def _run_async_in_thread(co: Coroutine[Any, Any, R]) -> R:
601624
"""Runs the provided coroutine.
602-
625+
603626
Checks if an event loop is running in this thread. If one is running in this thread,
604627
we use a separate thread to run the async code. Otherwise, run the code using asyncio.run.
605628
"""
629+
606630
def run_async(co: Coroutine):
607631
"""Helper function to run the coroutine."""
608632
return asyncio.run(co)
@@ -611,7 +635,7 @@ def run_async(co: Coroutine):
611635
loop = None
612636
try:
613637
loop = asyncio.get_running_loop()
614-
except:
638+
except Exception:
615639
pass
616640

617641
if loop is None:
@@ -625,5 +649,3 @@ def run_async(co: Coroutine):
625649
out = future.result()
626650

627651
return out
628-
629-

0 commit comments

Comments
 (0)