Merged

2 changes: 1 addition & 1 deletion .github/workflows/quality.yml
@@ -49,7 +49,7 @@ jobs:
run: ollama pull llama3.2:1b

- name: Run Tests
run: uv run -m pytest -v test -n auto --dist loadscope
run: uv run -m pytest -v test
- name: Send failure message tests
if: failure() # This step will only run if a previous step failed
run: echo "Tests failed. Please verify that tests are working locally."
51 changes: 51 additions & 0 deletions docs/examples/context/contexts_with_sampling.py
@@ -0,0 +1,51 @@
from mellea.backends.types import ModelOption
from mellea.stdlib.sampling.base import RejectionSamplingStrategy
from mellea.stdlib.session import start_session

# You can retrieve context information when using SamplingStrategies
# and validation.

m = start_session()

# We want the full SamplingResult.
res = m.instruct(
"Write a sentence.",
requirements=["be funny", "be formal", "start the sentence with the letter w"],
strategy=RejectionSamplingStrategy(loop_budget=3),
return_sampling_results=True,
)

print()
print("Printing result of `Writing a sentence`.")
print(f"Result: {res.success}")
print(f"Result Output: {res.result}")
print()

# We can also look at the context for the chosen result and
# any other results that weren't chosen.
# (This prompt tends to take 2 attempts. If it only takes one, try re-running it.)
print(f"Total Generation Attempts: {len(res.sample_generations)}")
print()

print(f"Getting index of another result.")
index = 0 # Just choose the first one.

print(
"If the output below is the same as the one above, re-run this program to get multiple attempts."
)
print(f"Different attempted output: {res.sample_generations[index]}")
print()

# We can see the context that created this output.
gen_ctx = res.sample_contexts[index]
print(f"Previous step in generating this result was: {gen_ctx.previous_node.node_data}")
print()

# We can also see what the validation context looked like.
req, val_result = res.sample_validations[index][0]
print(
f"Getting context when evaluating the above output against Req({req.description})."
)
val_ctx = val_result.context

print(f"Output of the validation for this requirement: {val_ctx.node_data}")
File renamed without changes.
10 changes: 9 additions & 1 deletion mellea/backends/huggingface.py
@@ -206,7 +206,7 @@ def generate_from_context(
mot = self._generate_from_context_alora(
action, ctx, _format=format, model_options=model_opts
)
return mot, ctx.add(mot)
return mot, ctx.add(action).add(mot)
else:
mot = self._generate_from_context_standard(
action,
@@ -512,6 +512,14 @@ def generate_from_raw(
"The raw endpoint does not support tool calling at the moment."
)

if self._model.device.type == "mps":
# TODO: Remove this when we are able to update the torch package.
# Test this by ensuring all outputs from this call are populated when running on mps.
# https://github.com/pytorch/pytorch/pull/157727
FancyLogger.get_logger().warning(
"Using device mps with a `generate_from_raw` request; batches of prompts submitted to a huggingface backend may come back empty. Ensure all returned ModelOutputThunks have non-empty values."
)

Comment on lines +519 to +522
Contributor: Ah, that's what this was! I was having issues when running the HF tests for this, but they disappeared when I stepped in while debugging. Thanks for adding this!
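
For anyone who hits this: a quick way to surface the failure mode the warning describes is to check the returned thunks directly. A rough sketch, assuming `results` is the list of `ModelOutputThunk` objects returned by a batched `generate_from_raw` call on an mps device:

```python
# Hypothetical post-check for the mps batching issue described in the warning;
# `results` is assumed to hold the ModelOutputThunks from a batched call.
empty = [i for i, thunk in enumerate(results) if not thunk.value]
if empty:
    raise RuntimeError(
        f"Empty outputs at prompt indices {empty}; "
        "see https://github.com/pytorch/pytorch/pull/157727."
    )
```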

model_opts = self._simplify_and_merge(model_options)
seed = model_opts.get(ModelOption.SEED, None)
if seed is not None:
7 changes: 3 additions & 4 deletions mellea/backends/litellm.py
@@ -270,17 +270,16 @@ def _generate_from_chat_context_standard(
[OpenAIBackend.message_to_openai_message(m) for m in messages]
)

extra_params: dict[str, Any] = {}
if _format is not None:
response_format = {
extra_params["response_format"] = {
Contributor: Any reason for the additional abstraction?

Contributor: I agree with Avi; `response_format = None` is better if the old value causes errors.

Contributor (author): `response_format = None` sometimes causes issues with some backends as well (at least, I believe the OpenAI backend used to). It's best to just not pass a `response_format` parameter at all if possible.

"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
else:
response_format = {"type": "text"}

thinking = model_opts.get(ModelOption.THINKING, None)
if type(thinking) is bool and thinking:
@@ -304,9 +303,9 @@
model=self._model_id,
messages=conversation,
tools=formatted_tools,
response_format=response_format,
reasoning_effort=thinking, # type: ignore
drop_params=True, # See note in `_make_backend_specific_and_remove`.
**extra_params,
**model_specific_options,
)
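
The pattern adopted here (and mirrored in the OpenAI backend below) is to build the keyword arguments conditionally, so `response_format` is simply absent when no schema is requested rather than being passed as `None` or `{"type": "text"}`. A standalone sketch of that pattern; the helper name and signature are illustrative, not the actual mellea code:

```python
from typing import Any

import pydantic


def structured_output_params(schema_model: type[pydantic.BaseModel] | None) -> dict[str, Any]:
    """Return kwargs that request structured output only when a schema is given."""
    extra_params: dict[str, Any] = {}
    if schema_model is not None:
        extra_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": schema_model.__name__,
                "schema": schema_model.model_json_schema(),
                "strict": True,
            },
        }
    return extra_params


# Usage: litellm.acompletion(model=..., messages=..., **structured_output_params(fmt))
```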

7 changes: 3 additions & 4 deletions mellea/backends/openai.py
@@ -465,17 +465,16 @@ def _generate_from_chat_context_standard(
conversation.append({"role": "system", "content": system_prompt})
conversation.extend([self.message_to_openai_message(m) for m in messages])

extra_params: dict[str, Any] = {}
if _format is not None:
response_format = {
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
else:
response_format = {"type": "text"}

# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
@@ -507,9 +506,9 @@ def _generate_from_chat_context_standard(
model=self._hf_model_id,
messages=conversation, # type: ignore
reasoning_effort=thinking, # type: ignore
response_format=response_format, # type: ignore
tools=formatted_tools if use_tools else None, # type: ignore
# parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
**extra_params,
**self._make_backend_specific_and_remove(
model_opts, is_chat_context=ctx.is_chat_context
),
37 changes: 34 additions & 3 deletions mellea/backends/vllm.py
@@ -35,7 +35,8 @@
convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
from mellea.helpers.event_loop_helper import _run_async_in_thread
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
CBlock,
@@ -80,6 +81,14 @@ def __init__(
formatter (Formatter): A mechanism for turning `stdlib` stuff into strings. Experimental Span-based models should use `mellea.backends.span.*` backends.
model_options (Optional[dict]): Default model options.
"""
if os.environ.get("VLLM_USE_V1", "-1") != "0":
FancyLogger.get_logger().error(
"Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
)
raise ValueError(
"Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
)

formatter = (
formatter if formatter is not None else TemplateFormatter(model_id=model_id)
)
@@ -140,7 +149,7 @@ def __init__(
while True:
retry += 1
try:
self._model = vllm.AsyncLLMEngine.from_engine_args(
self._underlying_model = vllm.AsyncLLMEngine.from_engine_args(
vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
)
break
@@ -192,6 +201,9 @@ def __init__(
f"max_num_seqs: {engine_args['max_num_seqs']}\n"
)

# Keep track of the event loop the engine was instantiated in.
self._event_loop = get_current_event_loop()

self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
self._hf_model_id
) # type:ignore
@@ -205,6 +217,24 @@ def __init__(
"outlines.models.vllm"
).adapt_tokenizer(self._tokenizer)

@property
def _model(self) -> vllm.AsyncLLMEngine:
"""Use model when making generation requests."""
el = get_current_event_loop()

# vLLM attaches itself to the event loop that is running when instantiated /
# the first generate request is made. Thankfully, they provide helpers to
# reset that. We do that here if the event loop changes.

# Most of the time, this should be a no-op. The event loop will only change
# if switching between async and sync calls.
if el != self._event_loop:
self._underlying_model.shutdown_background_loop()
self._underlying_model.start_background_loop()
Contributor (author): They call it a background_loop, but it's not an event loop; it's actually a Future. Even the `_background_loop_unshielded` is a Task object.

I think it's fine to manage the reference to the event loop on our side. We only ever have the one AsyncLLMEngine per LocalVLLMBackend, so there shouldn't be issues with tracking it this way. Happy to change it if it causes issues later on.

self._event_loop = el

return self._underlying_model

def generate_from_context(
self,
action: Component | CBlock,
@@ -447,7 +477,8 @@ async def generate_all(prompts):
tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
return await asyncio.gather(*tasks)

decoded_results = asyncio.run(generate_all(prompts))
# Allow calling this from async functions.
decoded_results = _run_async_in_thread(generate_all(prompts))

results = [ModelOutputThunk(value=text) for text in decoded_results]
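
For context on the helper used above: `_run_async_in_thread` is assumed to run a coroutine to completion on a dedicated thread with its own event loop, so the call works whether or not the caller is already inside a running loop (where a bare `asyncio.run` would raise). A rough sketch of that pattern, not necessarily the actual `mellea.helpers.event_loop_helper` implementation:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def run_async_in_thread(coro):
    # asyncio.run creates a fresh event loop on the worker thread, so this is
    # safe to call from sync code and from code already running under a loop.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```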

14 changes: 12 additions & 2 deletions mellea/stdlib/requirement.py
@@ -47,6 +47,7 @@ def __init__(
reason: str | None = None,
score: float | None = None,
thunk: ModelOutputThunk | None = None,
context: Context | None = None,
):
"""The result of a requirement's validation.

@@ -57,11 +58,13 @@
reason: a reason for the result
score: if your validator gives you a score back, you can add this as metadata
thunk: if your validator utilizes a backend to generate a response, the ModelOutputThunk returned from that request
context: if your validator utilizes a backend to generate a response, the context associated with that response
"""
self._result = result
self._reason = reason
self._score = score
self._thunk = thunk
self._context = context

@property
def reason(self) -> str | None:
@@ -78,6 +81,11 @@ def thunk(self) -> ModelOutputThunk | None:
"""The ModelOutputThunk associated with the validation func if an llm was used to generate the final result."""
return self._thunk

@property
def context(self) -> Context | None:
"""The context associated with validation if a backend was used to generate the final result."""
return self._context

def as_bool(self) -> bool:
"""Return a boolean value based on the result."""
return self._result
@@ -140,7 +148,7 @@ async def validate(
# and its template gets populated with the output correctly.
req_copy = copy(self)
req_copy._output = last_output.value
llm_as_a_judge_result, _ = backend.generate_from_context(
llm_as_a_judge_result, val_ctx = backend.generate_from_context(
req_copy, ctx, format=format, model_options=model_options
)
await llm_as_a_judge_result.avalue()
@@ -149,6 +157,7 @@
result=self.output_to_bool(llm_as_a_judge_result),
reason=llm_as_a_judge_result.value,
thunk=llm_as_a_judge_result,
context=val_ctx,
)

def parts(self):
@@ -252,7 +261,7 @@ async def validate(
# and its template gets populated with the output correctly.
req_copy = copy(self)
req_copy._output = last_output.value
llm_as_a_judge_result, _ = backend.generate_from_context(
llm_as_a_judge_result, val_ctx = backend.generate_from_context(
req_copy, ctx, format=format, model_options=model_options
)
await llm_as_a_judge_result.avalue()
@@ -263,6 +272,7 @@
reason=llm_as_a_judge_result.value,
score=1 if result else 0,
thunk=llm_as_a_judge_result,
context=val_ctx,
)
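
With the new `context` property, anything that receives a `ValidationResult` can inspect the generation context behind an LLM-as-a-judge verdict. A small usage sketch, reusing the names from the docs example added in this PR:

```python
# Assumes `res` is a SamplingResult obtained with return_sampling_results=True,
# as in docs/examples/context/contexts_with_sampling.py.
for req, val_result in res.sample_validations[0]:
    if val_result.context is not None:
        print(f"Req({req.description}) judged via: {val_result.context.node_data}")
    else:
        print(f"Req({req.description}) was validated without a backend call.")
```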


4 changes: 2 additions & 2 deletions mellea/stdlib/safety/guardian.py
@@ -315,7 +315,7 @@ async def validate(
# Use a CBlock for HuggingFace - it won't be added as a message
action = CBlock("") # type: ignore

mot, _ = self._backend.generate_from_context(
mot, val_ctx = self._backend.generate_from_context(
action, gctx, model_options=guardian_options
)
await mot.avalue()
@@ -337,5 +337,5 @@
reason_parts.append(f"Reasoning: {trace}")

return ValidationResult(
result=is_safe, reason="; ".join(reason_parts), thunk=mot
result=is_safe, reason="; ".join(reason_parts), thunk=mot, context=val_ctx
)
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -82,7 +82,7 @@ vllm = [
"numpy<2.0.0", # patching incorrect dependencies in vllm and outlines.
# see https://github.com/vllm-project/vllm/issues/5587
"outlines-core==0.1.26",
"vllm", # intentionally un-versioned, expecting a minor update. coutlines-core version should be enough to specify it
"vllm>=0.9.1",
]

litellm = [
@@ -112,7 +112,6 @@ dev = [
"pytest-asyncio",
"mypy>=1.17.0",
"python-semantic-release~=7.32",
"pytest-xdist>=3.8.0",
]

notebook = [
13 changes: 9 additions & 4 deletions test/backends/test_huggingface.py
@@ -9,7 +9,7 @@
from mellea.backends.formatter import TemplateFormatter
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, ChatContext, SimpleContext
from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk, SimpleContext
from mellea.stdlib.requirement import (
ALoraRequirement,
LLMaJRequirement,
@@ -117,6 +117,14 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
val_result = validation_outputs[0]
assert isinstance(val_result, ValidationResult)
assert str(val_result.reason) in ["Y", "N"]

# Ensure the ValidationResult has its thunk and context set. Ensure the context has
# the correct actions / results in it.
assert isinstance(val_result.context, Context)
assert isinstance(val_result.thunk, ModelOutputThunk)
assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement)
assert val_result.context.node_data is val_result.thunk

backend.default_to_constraint_checking_alora = True


@@ -150,7 +158,6 @@ def test_multiturn(session):
"Take the result of the previous sum and find the corresponding letter in the greek alphabet.",
model_options={ModelOption.MAX_NEW_TOKENS: 300},
)
assert "β" in str(beta).lower()
words = session.instruct("Now list five English words that start with that letter.")
print(words)

@@ -193,7 +200,6 @@ class Email(pydantic.BaseModel):
"The email address should be at example.com"
)


@pytest.mark.qualitative
def test_generate_from_raw(session):
prompts = [
Expand All @@ -210,7 +216,6 @@ def test_generate_from_raw(session):

assert len(results) == len(prompts)


@pytest.mark.qualitative
def test_generate_from_raw_with_format(session):
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]