Skip to content

Commit d21fe64

Browse files
author
Maxwell Crouse [email protected]
committed
update formatting to ruff
1 parent a01f694 commit d21fe64

File tree

8 files changed

+35
-27
lines changed

8 files changed

+35
-27
lines changed

.pre-commit-config.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ repos:
1313
args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
1414
files: '^(mellea|tests).*\.(py|ipynb)$'
1515

16-
# - repo: local
17-
# hooks:
18-
# - id: mypy
19-
# name: MyPy
20-
# entry: uv run --no-sync mypy mellea
21-
# pass_filenames: false
22-
# language: system
23-
# files: '\.py$'
16+
- repo: local
17+
hooks:
18+
- id: mypy
19+
name: MyPy
20+
entry: uv run --no-sync mypy mellea --no-namespace-packages
21+
pass_filenames: false
22+
language: system
23+
files: '\.py$'
2424

2525
- repo: https://github.com/astral-sh/uv-pre-commit
2626
rev: 0.7.8

mellea/backends/openai.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,8 @@ def generate_from_chat_context(
323323
action: Component | CBlock | None,
324324
ctx: Context,
325325
*,
326-
_format: (
327-
type[BaseModelSubclass] | None
328-
) = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
326+
_format: type[BaseModelSubclass]
327+
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
329328
model_options: dict | None = None,
330329
tool_calls: bool = False,
331330
labels: Sequence[str] | None = None,
@@ -360,9 +359,8 @@ def _generate_from_chat_context_alora(
360359
action: Component | CBlock,
361360
ctx: Context,
362361
*,
363-
_format: (
364-
type[BaseModelSubclass] | None
365-
) = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
362+
_format: type[BaseModelSubclass]
363+
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
366364
model_options: dict | None = None,
367365
) -> ModelOutputThunk:
368366
match action:
@@ -447,9 +445,8 @@ def _generate_from_chat_context_standard(
447445
action: Component | CBlock | None,
448446
ctx: Context,
449447
*,
450-
_format: (
451-
type[BaseModelSubclass] | None
452-
) = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
448+
_format: type[BaseModelSubclass]
449+
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
453450
model_options: dict | None = None,
454451
tool_calls: bool = False,
455452
labels: Sequence[str] | None = None,
@@ -719,11 +716,9 @@ def generate_from_raw(
719716
output._model_options = model_opts
720717
output._meta = {
721718
"oai_completion_response": response.model_dump(),
722-
"usage": (
723-
completion_response.usage.model_dump()
724-
if completion_response.usage
725-
else None
726-
),
719+
"usage": completion_response.usage.model_dump()
720+
if completion_response.usage
721+
else None,
727722
}
728723

729724
self.formatter.parse(action, output)

mellea/stdlib/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ class Context(abc.ABC):
407407
_data: Component | CBlock | None
408408
_is_root: bool
409409
_is_chat_context: bool = True
410-
_labels: set[str] = None
410+
_labels: set[str] | None = None
411411

412412
def __init__(self):
413413
"""Constructs a new root context with no content."""

mellea/stdlib/functional.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,8 @@ def validate(
269269
output: CBlock | None = None,
270270
format: type[BaseModelSubclass] | None = None,
271271
model_options: dict | None = None,
272-
generate_logs: (
273-
list[GenerateLog] | None
274-
) = None, # TODO: Can we get rid of gen logs here and in act?
272+
generate_logs: list[GenerateLog]
273+
| None = None, # TODO: Can we get rid of gen logs here and in act?
275274
input: CBlock | None = None,
276275
) -> list[ValidationResult]:
277276
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""

mellea/stdlib/sampling/best_of_n.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Best of N Sampling Strategy."""
22

3+
from collections.abc import Sequence
34
from copy import deepcopy
45

56
import tqdm
@@ -29,6 +30,7 @@ async def sample(
2930
model_options: dict | None = None,
3031
tool_calls: bool = False,
3132
show_progress: bool = True,
33+
labels: Sequence[str] | None = None,
3234
) -> SamplingResult:
3335
"""This method performs a sampling operation based on the given instruction.
3436
@@ -42,6 +44,7 @@ async def sample(
4244
model_options: model options to pass to the backend during generation / validation.
4345
tool_calls: True if tool calls should be used during this sampling strategy.
4446
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
47+
labels: if provided, restrict generation to context nodes with matching types.
4548
4649
Returns:
4750
SamplingResult: A result object indicating the success or failure of the sampling process.
@@ -114,6 +117,7 @@ async def sample(
114117
format=format,
115118
model_options=model_options,
116119
tool_calls=tool_calls,
120+
labels=labels,
117121
)
118122
sampled_results.append(result)
119123
sampled_actions.append(next_action)

mellea/stdlib/sampling/budget_forcing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Sampling Strategies for budget forcing generation."""
22

3+
from collections.abc import Sequence
34
from copy import deepcopy
45

56
import tqdm
67

78
from mellea.backends import Backend, BaseModelSubclass
89
from mellea.backends.ollama import OllamaModelBackend
910
from mellea.helpers.fancy_logger import FancyLogger
10-
from mellea.stdlib import funcs as mfuncs
11+
from mellea.stdlib import functional as mfuncs
1112
from mellea.stdlib.base import ModelOutputThunk
1213
from mellea.stdlib.requirement import Requirement, ValidationResult
1314
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
@@ -82,6 +83,7 @@ async def sample(
8283
model_options: dict | None = None,
8384
tool_calls: bool = False,
8485
show_progress: bool = True,
86+
labels: Sequence[str] | None = None,
8587
) -> SamplingResult:
8688
"""This method performs a sampling operation based on the given instruction.
8789
@@ -95,6 +97,7 @@ async def sample(
9597
model_options: model options to pass to the backend during generation / validation.
9698
tool_calls: True if tool calls should be used during this sampling strategy.
9799
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
100+
labels: if provided, restrict generation to context nodes with matching types.
98101
99102
Returns:
100103
SamplingResult: A result object indicating the success or failure of the sampling process.

mellea/stdlib/sampling/majority_voting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
from collections.abc import Sequence
56

67
import numpy as np
78
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
@@ -74,6 +75,7 @@ async def sample(
7475
model_options: dict | None = None,
7576
tool_calls: bool = False,
7677
show_progress: bool = True,
78+
labels: Sequence[str] | None = None,
7779
) -> SamplingResult:
7880
"""Samples using majority voting.
7981
@@ -87,6 +89,7 @@ async def sample(
8789
model_options: model options to pass to the backend during generation / validation.
8890
tool_calls: True if tool calls should be used during this sampling strategy.
8991
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
92+
labels: if provided, restrict generation to context nodes with matching types.
9093
9194
Returns:
9295
SamplingResult: A result object indicating the success or failure of the sampling process.
@@ -104,6 +107,7 @@ async def sample(
104107
model_options=model_options,
105108
tool_calls=tool_calls,
106109
show_progress=show_progress,
110+
labels=labels,
107111
)
108112
)
109113
tasks.append(task)

mellea/stdlib/sampling/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Base types for sampling."""
22

33
import abc
4+
from collections.abc import Sequence
45

56
from mellea.backends import Backend, BaseModelSubclass
67
from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk
@@ -95,6 +96,7 @@ async def sample(
9596
format: type[BaseModelSubclass] | None = None,
9697
model_options: dict | None = None,
9798
tool_calls: bool = False,
99+
labels: Sequence[str] | None = None,
98100
) -> SamplingResult:
99101
"""This method is the abstract method for sampling a given component.
100102
@@ -109,6 +111,7 @@ async def sample(
109111
format: output format for structured outputs.
110112
model_options: model options to pass to the backend during generation / validation.
111113
tool_calls: True if tool calls should be used during this sampling strategy.
114+
labels: if provided, restrict generation to context nodes with matching types.
112115
113116
Returns:
114117
SamplingResult: A result object indicating the success or failure of the sampling process.

0 commit comments

Comments
 (0)