
Commit e70d307

jakelorocco and csbobby authored
fix: pre-commit file selection (#243)
* fix: fix pre-commit for mypy and formatter

  Co-authored-by: Bobby <[email protected]>

* fix: fix pre-commit for formatter and fix files

---------

Co-authored-by: Bobby <[email protected]>
1 parent 3087051 commit e70d307

File tree

12 files changed: +126 -60 lines changed


.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'

   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'

   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
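
Why this fixes file selection: pre-commit treats `files` as a Python regular expression searched against each staged path, and this repo keeps its tests under test/ (see the file paths throughout this commit), so the old `tests` alternative never matched them and the hooks quietly skipped the test suite. A minimal sketch of the matching behavior, not part of the commit (the scripts/ path is hypothetical, for contrast):

import re

OLD = r"^(mellea|tests|cli|docs).*\.(py|ipynb)$"
NEW = r"^(mellea|test|cli|docs).*\.(py|ipynb)$"

for path in [
    "mellea/backends/ollama.py",
    "test/stdlib_basics/test_chat.py",
    "scripts/helper.py",
]:
    # pre-commit matches with re.search; the ^...$ anchors make it exact.
    print(path, bool(re.search(OLD, path)), bool(re.search(NEW, path)))

# mellea/backends/ollama.py True True
# test/stdlib_basics/test_chat.py False True   <-- the bug this commit fixes
# scripts/helper.py False False

Note that the mypy hook keeps pass_filenames: false and always checks the whole package (uv run --no-sync mypy mellea); its `files` pattern therefore only decides whether the hook fires at all when matching files are staged.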

test/backends/test_huggingface.py

Lines changed: 15 additions & 5 deletions
@@ -10,11 +10,20 @@
 from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
-                                SimpleContext)
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, ValidationResult,
-                                       default_output_to_bool)
+from mellea.stdlib.base import (
+    CBlock,
+    ChatContext,
+    Context,
+    ModelOutputThunk,
+    SimpleContext,
+)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)


 @pytest.fixture(scope="module")
@@ -40,6 +49,7 @@ def session(backend):
     yield session
     session.reset()

+
 @pytest.mark.qualitative
 def test_adapters(backend):
     assert len(backend._added_adapters.items()) > 0
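
The bulk of the remaining hunks in this commit are mechanical: once the corrected `files` pattern let Ruff's formatter see test/, it rewrote imports and calls in its default Black-compatible style. The driver behind expanded blocks like the imports above is the "magic trailing comma": when a line is too long, the formatter splits it one element per line and adds a trailing comma, and a trailing comma in turn pins the expanded form even when it would fit. A small sketch, assuming Ruff's default skip-magic-trailing-comma = false:

# The trailing comma keeps this import expanded, one name per line:
from mellea.stdlib.base import (
    CBlock,
    ChatContext,
)

# Without the trailing comma, the formatter collapses it when it fits:
from mellea.stdlib.base import CBlock, ChatContext

The same rule explains the collapsed calls later in this commit, where removing a trailing comma let a multi-line call fold back onto one line.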

test/backends/test_litellm_ollama.py

Lines changed: 8 additions & 8 deletions
@@ -26,9 +26,7 @@ def backend(gh_run: int):
         url = url.replace("127.0.0.1", "http://localhost")

         return LiteLLMBackend(
-            model_id=_MODEL_ID,
-            base_url=url,
-            model_options={"api_base": url},
+            model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}
         )
     else:
         return LiteLLMBackend(model_id=_MODEL_ID)
@@ -111,12 +109,11 @@ def test_litellm_ollama_instruct_options(session):
         ModelOption.SEED: 123,
         ModelOption.TEMPERATURE: 0.5,
         ModelOption.MAX_NEW_TOKENS: 100,
-
-        # Ollama thinking controls currently broken on Granite; see
+        # Ollama thinking controls currently broken on Granite; see
         # https://github.com/ollama/ollama/issues/10983
         # TODO: Re-enable when this upstream bug gets fixed.
-        #ModelOption.THINKING: True,
-        #"reasoning_effort": True,
+        # ModelOption.THINKING: True,
+        # "reasoning_effort": True,
         "homer_simpson": "option should be kicked out",
     }
@@ -144,6 +141,7 @@ def is_happy(text: str) -> bool:
     # should yield to true - but, of course, is model dependent
     assert h is True

+
 async def test_generate_from_raw(session):
     prompts = [
         "what is 1+1?",
@@ -157,7 +155,9 @@ async def test_generate_from_raw(session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    assert len(results) == 1, (
+        "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    )
     assert results[0].value is not None
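
The rewritten assertion above (and the matching one in the watsonx file below) is also formatter output, not a behavior change: `assert cond, (msg)` is identical to `assert cond, msg`, and the parentheses merely let an over-long message sit on its own lines. A quick sketch, with a stand-in value for illustration:

results = ["single combined response"]  # stand-in, for illustration only

# Semantically identical; only the layout differs:
assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
assert len(results) == 1, (
    "ollama doesn't support batching; litellm should send a single message containing all prompts"
)

# Caution: this is different from `assert (cond, msg)`, which builds a
# two-element tuple and is therefore always truthy.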

test/backends/test_litellm_watsonx.py

Lines changed: 4 additions & 7 deletions
@@ -41,18 +41,15 @@ def test_multiple_sync_funcs(session):

 @pytest.mark.qualitative
 async def test_generate_from_raw(session):
-    prompts = [
-        "what is 1+1?",
-        "what is 2+2?",
-        "what is 3+3?",
-        "what is 4+2+2?",
-    ]
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"]

     results = await session.backend.generate_from_raw(
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "litellm converts a batch request for watsonx into a single message"
+    assert len(results) == 1, (
+        "litellm converts a batch request for watsonx into a single message"
+    )
     assert results[0].value is not None

test/backends/test_openai_ollama.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ async def test_generate_from_raw(m_session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx
     )

+
 # Default OpenAI implementation doesn't support structured outputs for the completions API.
 # def test_generate_from_raw_with_format(self):
 #     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

test/backends/test_openai_vllm/test_openai_vllm.py

Lines changed: 11 additions & 4 deletions
@@ -11,8 +11,12 @@
 from mellea.backends.openai import OpenAIBackend
 from mellea.backends.types import ModelOption, _ServerType
 from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, req)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    req,
+)

 # The vllm tests are disabled by default, because we need a test environment with the vLLM server running.
 # We use an env var VLLM_TESTS_ENABLED to enable these tests.
@@ -138,8 +142,11 @@ class TestOpenAIALoraStuff:
             base_url="http://localhost:8000/v1",
             api_key="EMPTY",
         )
-        backend.add_adapter(GraniteCommonAdapter("requirement_check",
-                                                 base_model_name=backend.base_model_name))
+        backend.add_adapter(
+            GraniteCommonAdapter(
+                "requirement_check", base_model_name=backend.base_model_name
+            )
+        )

         m = MelleaSession(backend, ctx=ChatContext())

test/stdlib_basics/test_base.py

Lines changed: 1 addition & 0 deletions
@@ -26,5 +26,6 @@ def format_for_llm(self) -> str:
     c = _ClosuredComponent()
     assert len(c.parts()) == 0

+
 if __name__ == "__main__":
     pytest.main([__file__])

test/stdlib_basics/test_chat.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from mellea.stdlib.base import Document
 from mellea.stdlib.chat import Message

+
 def test_message_with_docs():
     doc = Document("I'm text!", "Im a title!")
     msg = Message("user", "hello", documents=[doc])

test/stdlib_basics/test_genslot.py

Lines changed: 72 additions & 21 deletions
@@ -5,22 +5,27 @@
 from mellea.backends.model_ids import META_LLAMA_3_2_1B
 from mellea.backends.ollama import OllamaModelBackend
 from mellea.stdlib.base import ChatContext, Context
-from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot
+from mellea.stdlib.genslot import (
+    AsyncGenerativeSlot,
+    GenerativeSlot,
+    PreconditionException,
+    SyncGenerativeSlot,
+)
 from mellea.stdlib.requirement import Requirement, simple_validate
 from mellea.stdlib.sampling.base import RejectionSamplingStrategy
 from mellea.stdlib.session import MelleaSession

+
 @pytest.fixture(scope="module")
 def backend(gh_run: int):
     """Shared backend."""
     if gh_run == 1:
         return OllamaModelBackend(
-            model_id=META_LLAMA_3_2_1B.ollama_name,  # type: ignore
+            model_id=META_LLAMA_3_2_1B.ollama_name  # type: ignore
         )
     else:
-        return OllamaModelBackend(
-            model_id="granite3.3:8b",
-        )
+        return OllamaModelBackend(model_id="granite3.3:8b")
+

 @generative
 def classify_sentiment(text: str) -> Literal["positive", "negative"]: ...
@@ -81,26 +86,66 @@ async def test_async_gen_slot(session):
     r1 = async_write_short_sentence(session, topic="cats")
     r2 = async_write_short_sentence(session, topic="dogs")

-    r3, c3 = await async_write_short_sentence(context=session.ctx, backend=session.backend, topic="fish")
+    r3, c3 = await async_write_short_sentence(
+        context=session.ctx, backend=session.backend, topic="fish"
+    )
     results = await asyncio.gather(r1, r2)

     assert isinstance(r3, str)
     assert isinstance(c3, Context)
     assert len(results) == 2

+
 @pytest.mark.parametrize(
     "arg_choices,kwarg_choices,errs",
     [
         pytest.param(["m"], ["func1", "func2", "func3"], False, id="session"),
         pytest.param(["context"], ["backend"], False, id="context and backend"),
-        pytest.param(["backend"], ["func1", "func2", "func3"], True, id="backend without context"),
+        pytest.param(
+            ["backend"], ["func1", "func2", "func3"], True, id="backend without context"
+        ),
         pytest.param(["m"], ["m"], True, id="duplicate arg and kwarg"),
-        pytest.param(["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], [], True, id="original func args as positional args"),
-        pytest.param([], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"),
-        pytest.param([], ["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], False, id="all kwargs"),
-        pytest.param([], ["func1", "m", "func2", "requirements", "func3"], False, id="interspersed kwargs"),
-        pytest.param([], [], True, id="missing required args")
-    ]
+        pytest.param(
+            [
+                "m",
+                "precondition_requirements",
+                "requirements",
+                "strategy",
+                "model_options",
+                "func1",
+                "func2",
+                "func3",
+            ],
+            [],
+            True,
+            id="original func args as positional args",
+        ),
+        pytest.param(
+            [], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"
+        ),
+        pytest.param(
+            [],
+            [
+                "m",
+                "precondition_requirements",
+                "requirements",
+                "strategy",
+                "model_options",
+                "func1",
+                "func2",
+                "func3",
+            ],
+            False,
+            id="all kwargs",
+        ),
+        pytest.param(
+            [],
+            ["func1", "m", "func2", "requirements", "func3"],
+            False,
+            id="interspersed kwargs",
+        ),
+        pytest.param([], [], True, id="missing required args"),
+    ],
 )
 def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
     """Tests the internal extract_args_and_kwargs function.
@@ -156,35 +201,40 @@ def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
     except Exception as e:
         found_err = True
         err = e
-
+
     if errs:
         assert found_err, "expected an exception and got none"
     else:
         assert not found_err, f"got unexpected err: {err}"

+
 def test_disallowed_parameter_names():
     with pytest.raises(ValueError):
+
         @generative
-        def test(backend):
-            ...
+        def test(backend): ...
+

 def test_precondition_failure(session):
     with pytest.raises(PreconditionException):
         classify_sentiment(
             m=session,
             text="hello",
             precondition_requirements=[
-                Requirement("forced failure", validation_fn=simple_validate(lambda x: (False, "")))
-            ]
+                Requirement(
+                    "forced failure",
+                    validation_fn=simple_validate(lambda x: (False, "")),
+                )
+            ],
         )

+
 def test_requirement(session):
     classify_sentiment(
-        m=session,
-        text="hello",
-        requirements=["req1", "req2", Requirement("req3")]
+        m=session, text="hello", requirements=["req1", "req2", Requirement("req3")]
     )

+
 def test_with_no_args(session):
     @generative
     def generate_text() -> str:
@@ -193,5 +243,6 @@ def generate_text() -> str:

     generate_text(m=session)

+
 if __name__ == "__main__":
     pytest.main([__file__])
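
Beyond formatting, this file documents the calling conventions that extract_args_and_kwargs enforces for @generative functions. Assembled only from the calls and test ids visible in this diff, not from the genslot source, the contract looks roughly like this:

async def calling_conventions(session):
    # 1. Pass a session via the reserved `m` keyword (id="session"):
    sentiment = classify_sentiment(m=session, text="hello")

    # 2. Or pass a context and backend explicitly (id="context and backend");
    #    the awaited async variant returns the value plus the updated context:
    result, ctx = await async_write_short_sentence(
        context=session.ctx, backend=session.backend, topic="fish"
    )
    return sentiment, result, ctx

# A backend without a context errs (id="backend without context"), as does
# passing the same value positionally and by keyword (id="duplicate arg and
# kwarg") or omitting everything (id="missing required args"). Reserved names
# such as `backend` also cannot be parameters of the decorated function
# itself; see test_disallowed_parameter_names.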

test/stdlib_basics/test_reqlib_tools.py

Lines changed: 3 additions & 1 deletion
@@ -1,10 +1,12 @@
 import pytest
 from mellea.stdlib.reqlib.tools import _name2str

+
 def test_name2str():
     """Test handling when no Python code is present."""
+
     def test123():
         pass
+
     assert _name2str(test123) == "test123"
     assert _name2str("test1234") == "test1234"
-