
Commit 7fa0891

fix: some minor fixes (#223)
* fix: enforce minimum vllm version
* fix: remove tests that look for "β"
* fix: remove default response_format from litellm and openai backends
* fix: remove xdist from pytests
* fix: fix vllm tests
* fix: vllm async event loop
* feat: add contexts to validation results
* fix: add warning for mps with huggingface generate from raw
* fix: remove .py from folder name
* fix: remove pytest-xdist specific args
* fix: add exception with vllm backend when env var not set
1 parent 54f13f4 commit 7fa0891

File tree

18 files changed: +157, -60 lines changed


.github/workflows/quality.yml

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ jobs:
         run: ollama pull llama3.2:1b

       - name: Run Tests
-        run: uv run -m pytest -v test -n auto --dist loadscope
+        run: uv run -m pytest -v test
       - name: Send failure message tests
         if: failure() # This step will only run if a previous step failed
         run: echo "Tests failed. Please verify that tests are working locally."
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+from mellea.backends.types import ModelOption
+from mellea.stdlib.sampling.base import RejectionSamplingStrategy
+from mellea.stdlib.session import start_session
+
+# You can retrieve context information when using SamplingStrategies
+# and validation.
+
+m = start_session()
+
+# We want the full SamplingResult.
+res = m.instruct(
+    "Write a sentence.",
+    requirements=["be funny", "be formal", "start the sentence with the letter w"],
+    strategy=RejectionSamplingStrategy(loop_budget=3),
+    return_sampling_results=True,
+)
+
+print()
+print("Printing result of `Writing a sentence`.")
+print(f"Result: {res.success}")
+print(f"Result Output: {res.result}")
+print()
+
+# We can also look at the context for the chosen result and
+# any other results that weren't chosen.
+# (This prompt tends to take 2 attempts. If it only takes one, try re-running it.)
+print(f"Total Generation Attempts: {len(res.sample_generations)}")
+print()
+
+print(f"Getting index of another result.")
+index = 0  # Just choose the first one.
+
+print(
+    "If the below is the same output, try re-running this program to get multiple attempts."
+)
+print(f"Different attempted output: {res.sample_generations[index]}")
+print()
+
+# We can see the context that created this output.
+gen_ctx = res.sample_contexts[index]
+print(f"Previous step in generating this result was: {gen_ctx.previous_node.node_data}")
+print()
+
+# We can also see what the validation context looked like.
+req, val_result = res.sample_validations[index][0]
+print(
+    f"Getting context when evaluating the above output against Req({req.description})."
+)
+val_ctx = val_result.context
+
+print(f"Output of the validation for this requirement: {val_ctx.node_data}")
File renamed without changes.

mellea/backends/huggingface.py

Lines changed: 9 additions & 1 deletion
@@ -206,7 +206,7 @@ def generate_from_context(
            mot = self._generate_from_context_alora(
                action, ctx, _format=format, model_options=model_opts
            )
-            return mot, ctx.add(mot)
+            return mot, ctx.add(action).add(mot)
        else:
            mot = self._generate_from_context_standard(
                action,
@@ -512,6 +512,14 @@ def generate_from_raw(
                "The raw endpoint does not support tool calling at the moment."
            )

+        if self._model.device.type == "mps":
+            # TODO: Remove this when we are able to update the torch package.
+            # Test this by ensuring all outputs from this call are populated when running on mps.
+            # https://github.com/pytorch/pytorch/pull/157727
+            FancyLogger.get_logger().warning(
+                "utilizing device mps with a `generate_from_raw` request; you may see issues when submitting batches of prompts to a huggingface backend; ensure all ModelOutputThunks have non-empty values."
+            )
+
        model_opts = self._simplify_and_merge(model_options)
        seed = model_opts.get(ModelOption.SEED, None)
        if seed is not None:
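
The new guard above keys only off the device that the model's weights live on. As a minimal standalone sketch of the same check (assuming only that torch is installed; the helper name and warning text below are illustrative, not part of mellea's API):

import torch

def warn_if_mps(model: torch.nn.Module) -> None:
    # Mirrors the guard added to generate_from_raw: batched generation on Apple's MPS
    # device can yield empty outputs until the torch fix referenced above lands.
    device = next(model.parameters()).device
    if device.type == "mps":
        print("warning: batched raw generation on mps; verify every output is non-empty.")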

mellea/backends/litellm.py

Lines changed: 3 additions & 4 deletions
@@ -270,17 +270,16 @@ def _generate_from_chat_context_standard(
            [OpenAIBackend.message_to_openai_message(m) for m in messages]
        )

+        extra_params: dict[str, Any] = {}
        if _format is not None:
-            response_format = {
+            extra_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": _format.__name__,
                    "schema": _format.model_json_schema(),
                    "strict": True,
                },
            }
-        else:
-            response_format = {"type": "text"}

        thinking = model_opts.get(ModelOption.THINKING, None)
        if type(thinking) is bool and thinking:
@@ -304,9 +303,9 @@ def _generate_from_chat_context_standard(
            model=self._model_id,
            messages=conversation,
            tools=formatted_tools,
-            response_format=response_format,
            reasoning_effort=thinking,  # type: ignore
            drop_params=True,  # See note in `_make_backend_specific_and_remove`.
+            **extra_params,
            **model_specific_options,
        )

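The change above only sends response_format when a schema was actually requested; otherwise the key is omitted entirely instead of defaulting to {"type": "text"}. A minimal sketch of that conditional-kwargs pattern, assuming pydantic v2 is installed (fake_completion is a stand-in for the provider call, not a real litellm or mellea function):

from typing import Any

from pydantic import BaseModel


def build_extra_params(fmt: type[BaseModel] | None) -> dict[str, Any]:
    # Only include response_format when structured output was requested.
    extra_params: dict[str, Any] = {}
    if fmt is not None:
        extra_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": fmt.__name__,
                "schema": fmt.model_json_schema(),
                "strict": True,
            },
        }
    return extra_params


def fake_completion(**kwargs: Any) -> dict[str, Any]:
    # Stand-in for the provider call: echoes back the keyword arguments it received.
    return kwargs


class Answer(BaseModel):
    text: str


assert "response_format" in fake_completion(**build_extra_params(Answer))
assert "response_format" not in fake_completion(**build_extra_params(None))

The same pattern is applied to the OpenAI backend in the next file.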
mellea/backends/openai.py

Lines changed: 3 additions & 4 deletions
@@ -465,17 +465,16 @@ def _generate_from_chat_context_standard(
            conversation.append({"role": "system", "content": system_prompt})
        conversation.extend([self.message_to_openai_message(m) for m in messages])

+        extra_params: dict[str, Any] = {}
        if _format is not None:
-            response_format = {
+            extra_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": _format.__name__,
                    "schema": _format.model_json_schema(),
                    "strict": True,
                },
            }
-        else:
-            response_format = {"type": "text"}

        # Append tool call information if applicable.
        tools: dict[str, Callable] = dict()
@@ -507,9 +506,9 @@ def _generate_from_chat_context_standard(
            model=self._hf_model_id,
            messages=conversation,  # type: ignore
            reasoning_effort=thinking,  # type: ignore
-            response_format=response_format,  # type: ignore
            tools=formatted_tools if use_tools else None,  # type: ignore
            # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
+            **extra_params,
            **self._make_backend_specific_and_remove(
                model_opts, is_chat_context=ctx.is_chat_context
            ),

mellea/backends/vllm.py

Lines changed: 34 additions & 3 deletions
@@ -35,7 +35,8 @@
    convert_tools_to_json,
)
from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
+from mellea.helpers.event_loop_helper import _run_async_in_thread
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
    CBlock,
@@ -80,6 +81,14 @@ def __init__(
            formatter (Formatter): A mechanism for turning `stdlib` stuff into strings. Experimental Span-based models should use `mellea.backends.span.*` backends.
            model_options (Optional[dict]): Default model options.
        """
+        if os.environ.get("VLLM_USE_V1", -1) != "0":
+            FancyLogger.get_logger().error(
+                "Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
+            )
+            raise ValueError(
+                "Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
+            )
+
        formatter = (
            formatter if formatter is not None else TemplateFormatter(model_id=model_id)
        )
@@ -140,7 +149,7 @@ def __init__(
        while True:
            retry += 1
            try:
-                self._model = vllm.AsyncLLMEngine.from_engine_args(
+                self._underlying_model = vllm.AsyncLLMEngine.from_engine_args(
                    vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
                )
                break
@@ -192,6 +201,9 @@ def __init__(
            f"max_num_seqs: {engine_args['max_num_seqs']}\n"
        )

+        # Keep track of the event loop the engine was instantiated in.
+        self._event_loop = get_current_event_loop()
+
        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            self._hf_model_id
        )  # type:ignore
@@ -205,6 +217,24 @@ def __init__(
            "outlines.models.vllm"
        ).adapt_tokenizer(self._tokenizer)

+    @property
+    def _model(self) -> vllm.AsyncLLMEngine:
+        """Use model when making generation requests."""
+        el = get_current_event_loop()
+
+        # vLLM attaches itself to the event loop that is running when instantiated /
+        # the first generate request is made. Thankfully, they provide helpers to
+        # reset that. We do that here if the event loop changes.
+
+        # Most of the time, this should be a no-op. The event loop will only change
+        # if switching between async and sync calls.
+        if el != self._event_loop:
+            self._underlying_model.shutdown_background_loop()
+            self._underlying_model.start_background_loop()
+            self._event_loop = el
+
+        return self._underlying_model
+
    def generate_from_context(
        self,
        action: Component | CBlock,
@@ -447,7 +477,8 @@ async def generate_all(prompts):
            tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
            return await asyncio.gather(*tasks)

-        decoded_results = asyncio.run(generate_all(prompts))
+        # Allow calling this from async functions.
+        decoded_results = _run_async_in_thread(generate_all(prompts))

        results = [ModelOutputThunk(value=text) for text in decoded_results]

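Swapping asyncio.run for _run_async_in_thread is what lets generate_from_raw be called even when an event loop is already running in the caller. A generic sketch of that pattern (not mellea's actual helper; the names below are illustrative) runs the coroutine on its own loop inside a worker thread and blocks on the result:

import asyncio
import concurrent.futures
from collections.abc import Coroutine
from typing import Any, TypeVar

T = TypeVar("T")


def run_async_in_thread(coro: Coroutine[Any, Any, T]) -> T:
    # asyncio.run would raise if another event loop is already running in this thread,
    # so execute the coroutine on a fresh loop in a worker thread and wait for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


async def demo() -> str:
    await asyncio.sleep(0)
    return "ok"


print(run_async_in_thread(demo()))  # callable from sync code or from within a running loop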
mellea/stdlib/requirement.py

Lines changed: 12 additions & 2 deletions
@@ -47,6 +47,7 @@ def __init__(
        reason: str | None = None,
        score: float | None = None,
        thunk: ModelOutputThunk | None = None,
+        context: Context | None = None,
    ):
        """The result of a requirement's validation.
@@ -57,11 +58,13 @@ def __init__(
            reason: a reason for the result
            score: if your validator gives you a score back, you can add this as metadata
            thunk: if your validator utilizes a backend to generate a response, the ModelOutputThunk returned from that request
+            context: if your validator utilizes a backend to generate a response, the context associated with that response
        """
        self._result = result
        self._reason = reason
        self._score = score
        self._thunk = thunk
+        self._context = context

    @property
    def reason(self) -> str | None:
@@ -78,6 +81,11 @@ def thunk(self) -> ModelOutputThunk | None:
        """The ModelOutputThunk associated with the validation func if an llm was used to generate the final result."""
        return self._thunk

+    @property
+    def context(self) -> Context | None:
+        """The context associated with validation if a backend was used to generate the final result."""
+        return self._context
+
    def as_bool(self) -> bool:
        """Return a boolean value based on the result."""
        return self._result
@@ -140,7 +148,7 @@ async def validate(
        # and its template gets populated with the output correctly.
        req_copy = copy(self)
        req_copy._output = last_output.value
-        llm_as_a_judge_result, _ = backend.generate_from_context(
+        llm_as_a_judge_result, val_ctx = backend.generate_from_context(
            req_copy, ctx, format=format, model_options=model_options
        )
        await llm_as_a_judge_result.avalue()
@@ -149,6 +157,7 @@ async def validate(
            result=self.output_to_bool(llm_as_a_judge_result),
            reason=llm_as_a_judge_result.value,
            thunk=llm_as_a_judge_result,
+            context=val_ctx,
        )

    def parts(self):
@@ -252,7 +261,7 @@ async def validate(
        # and its template gets populated with the output correctly.
        req_copy = copy(self)
        req_copy._output = last_output.value
-        llm_as_a_judge_result, _ = backend.generate_from_context(
+        llm_as_a_judge_result, val_ctx = backend.generate_from_context(
            req_copy, ctx, format=format, model_options=model_options
        )
        await llm_as_a_judge_result.avalue()
@@ -263,6 +272,7 @@ async def validate(
            reason=llm_as_a_judge_result.value,
            score=1 if result else 0,
            thunk=llm_as_a_judge_result,
+            context=val_ctx,
        )

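With the new context parameter, a custom validator can attach the backend context to its result the same way the built-in LLM-as-a-judge path now does. A minimal sketch, assuming the __init__ shown above belongs to a ValidationResult class importable from mellea.stdlib.requirement (backend_result and backend_ctx stand in for whatever your own backend.generate_from_context call returned):

from mellea.stdlib.requirement import ValidationResult


def judge(output_text: str, backend_result=None, backend_ctx=None) -> ValidationResult:
    return ValidationResult(
        result="yes" in output_text.lower(),
        reason=backend_result.value if backend_result is not None else None,
        thunk=backend_result,
        context=backend_ctx,  # new in this commit; read back later via result.context
    )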
