fix: some minor fixes #223
Changes from all commits: e0ddfe2, 70bac6a, b0bd8ee, f2c067c, e521377, 7d27532, 70d92f1, fa4c4c6, b8582e6, 46bc909, 692b395
New file (example: inspecting sampling results and contexts):

@@ -0,0 +1,51 @@
from mellea.backends.types import ModelOption
from mellea.stdlib.sampling.base import RejectionSamplingStrategy
from mellea.stdlib.session import start_session

# You can retrieve context information when using SamplingStrategies
# and validation.

m = start_session()

# We want the full SamplingResult.
res = m.instruct(
    "Write a sentence.",
    requirements=["be funny", "be formal", "start the sentence with the letter w"],
    strategy=RejectionSamplingStrategy(loop_budget=3),
    return_sampling_results=True,
)

print()
print("Printing result of `Writing a sentence`.")
print(f"Result: {res.success}")
print(f"Result Output: {res.result}")
print()

# We can also look at the context for the chosen result and
# any other results that weren't chosen.
# (This prompt tends to take 2 attempts. If it only takes one, try re-running it.)
print(f"Total Generation Attempts: {len(res.sample_generations)}")
print()

print(f"Getting index of another result.")
index = 0  # Just choose the first one.

print(
    "If the below is the same output, try re-running this program to get multiple attempts."
)
print(f"Different attempted output: {res.sample_generations[index]}")
print()

# We can see the context that created this output.
gen_ctx = res.sample_contexts[index]
print(f"Previous step in generating this result was: {gen_ctx.previous_node.node_data}")
print()

# We can also see what the validation context looked like.
req, val_result = res.sample_validations[index][0]
print(
    f"Getting context when evaluating the above output against Req({req.description})."
)
val_ctx = val_result.context

print(f"Output of the validation for this requirement: {val_ctx.node_data}")
Backend chat-completion path (`_generate_from_chat_context_standard`):

@@ -270,17 +270,16 @@ def _generate_from_chat_context_standard(
             [OpenAIBackend.message_to_openai_message(m) for m in messages]
         )
 
+        extra_params: dict[str, Any] = {}
         if _format is not None:
-            response_format = {
+            extra_params["response_format"] = {
Contributor: any reason for the additional abstraction?

Contributor: I agree with avi;

Contributor (Author): …

                 "type": "json_schema",
                 "json_schema": {
                     "name": _format.__name__,
                     "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
-        else:
-            response_format = {"type": "text"}
 
         thinking = model_opts.get(ModelOption.THINKING, None)
         if type(thinking) is bool and thinking:
 
@@ -304,9 +303,9 @@ def _generate_from_chat_context_standard(
             model=self._model_id,
             messages=conversation,
             tools=formatted_tools,
-            response_format=response_format,
             reasoning_effort=thinking,  # type: ignore
             drop_params=True,  # See note in `_make_backend_specific_and_remove`.
+            **extra_params,
             **model_specific_options,
         )
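To make the change above easier to follow, here is a hedged, self-contained sketch of the optional-kwargs pattern the diff adopts: `response_format` is only present in the splatted dict when a schema was requested, so nothing is passed otherwise. `MySchema` and `create_completion` below are hypothetical stand-ins, not names from the backend.

```python
from typing import Any

from pydantic import BaseModel


class MySchema(BaseModel):  # hypothetical example schema
    answer: str


def build_extra_params(_format: type[BaseModel] | None) -> dict[str, Any]:
    """Mirror the diff's logic: only set response_format when a format is given."""
    extra_params: dict[str, Any] = {}
    if _format is not None:
        extra_params["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": _format.__name__,
                "schema": _format.model_json_schema(),
                "strict": True,
            },
        }
    return extra_params


print(build_extra_params(None))      # {} -> no response_format key is sent at all
print(build_extra_params(MySchema))  # strict JSON-schema response_format entry
# Usage sketch: create_completion(model=..., messages=..., **build_extra_params(MySchema))
```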
Local vLLM backend (`LocalVLLMBackend`):

@@ -35,7 +35,8 @@
     convert_tools_to_json,
 )
 from mellea.backends.types import ModelOption
-from mellea.helpers.async_helpers import send_to_queue
+from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
+from mellea.helpers.event_loop_helper import _run_async_in_thread
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,

@@ -80,6 +81,14 @@ def __init__(
             formatter (Formatter): A mechanism for turning `stdlib` stuff into strings. Experimental Span-based models should use `mellea.backends.span.*` backends.
             model_options (Optional[dict]): Default model options.
         """
+        if os.environ.get("VLLM_USE_V1", -1) != "0":
+            FancyLogger.get_logger().error(
+                "Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
+            )
+            raise ValueError(
+                "Mellea LocalVLLMBackend doesn't support VLLM V1. Must `export VLLM_USE_V1=0`."
+            )
+
         formatter = (
             formatter if formatter is not None else TemplateFormatter(model_id=model_id)
         )
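As a hedged aside on the new guard: since `__init__` reads the variable from the process environment, it has to be set before the backend is constructed, e.g. `export VLLM_USE_V1=0` in the shell or, as sketched below, at the top of the driving script (`LocalVLLMBackend` is referenced only as it appears in the error message).

```python
import os

# Must happen before the vLLM engine / LocalVLLMBackend is created,
# otherwise the guard above raises ValueError.
os.environ["VLLM_USE_V1"] = "0"
```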

@@ -140,7 +149,7 @@ def __init__(
         while True:
             retry += 1
             try:
-                self._model = vllm.AsyncLLMEngine.from_engine_args(
+                self._underlying_model = vllm.AsyncLLMEngine.from_engine_args(
                     vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args)
                 )
                 break

@@ -192,6 +201,9 @@ def __init__(
                 f"max_num_seqs: {engine_args['max_num_seqs']}\n"
             )
 
+        # Keep track of the event loop the engine was instantiated in.
+        self._event_loop = get_current_event_loop()
+
         self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
             self._hf_model_id
         )  # type:ignore

@@ -205,6 +217,24 @@ def __init__(
             "outlines.models.vllm"
         ).adapt_tokenizer(self._tokenizer)
 
+    @property
+    def _model(self) -> vllm.AsyncLLMEngine:
+        """Use model when making generation requests."""
+        el = get_current_event_loop()
+
+        # vLLM attaches itself to the event loop that is running when instantiated /
+        # the first generate request is made. Thankfully, they provide helpers to
+        # reset that. We do that here if the event loop changes.
+
+        # Most of the time, this should be a no-op. The event loop will only change
+        # if switching between async and sync calls.
+        if el != self._event_loop:
+            self._underlying_model.shutdown_background_loop()
+            self._underlying_model.start_background_loop()

Contributor: optional: …

Contributor (Author): They call that a …. I think it's fine to manage the reference to the event loop on our side. We only ever have the one ….

+            self._event_loop = el
+
+        return self._underlying_model
+
     def generate_from_context(
         self,
         action: Component | CBlock,
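For readers unfamiliar with the event-loop bookkeeping above, here is a hedged, standalone sketch of the detection idea. `get_current_event_loop` is not shown in this diff; the sketch assumes it behaves roughly like this: return the running loop when called from async code and `None` from synchronous code, which is what makes the comparison in the `_model` property meaningful.

```python
import asyncio


def get_current_event_loop_sketch():
    """Assumed behavior: the running loop inside async code, otherwise None."""
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        return None


# From synchronous code there is no running loop.
assert get_current_event_loop_sketch() is None


async def main() -> None:
    # Inside a coroutine the running loop is returned, so a change of loop
    # (e.g. a sync call followed later by an async call) is detectable by comparison.
    assert get_current_event_loop_sketch() is asyncio.get_running_loop()


asyncio.run(main())
```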

@@ -447,7 +477,8 @@ async def generate_all(prompts):
             tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)]
             return await asyncio.gather(*tasks)
 
-        decoded_results = asyncio.run(generate_all(prompts))
+        # Allow calling this from async functions.
+        decoded_results = _run_async_in_thread(generate_all(prompts))
 
         results = [ModelOutputThunk(value=text) for text in decoded_results]
Review comment: Ah, that's what this was!! I was having issues when I was running hf tests for this, but it disappeared when I stepped into the … while debugging. Thanks for adding this!
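Finally, a hedged sketch of what a helper like `_run_async_in_thread` plausibly does (the real `mellea.helpers.event_loop_helper` implementation is not part of this diff): run the coroutine to completion on a worker thread with its own fresh event loop, so the call works whether or not the caller is already inside a running loop.

```python
import asyncio
import threading
from typing import Any


def run_async_in_thread_sketch(coro) -> Any:
    """Run `coro` on a new event loop in a worker thread and return its result."""
    result: dict[str, Any] = {}

    def runner() -> None:
        # asyncio.run creates and later closes a brand-new loop on this thread,
        # so it never collides with a loop already running in the caller's thread.
        result["value"] = asyncio.run(coro)

    t = threading.Thread(target=runner)
    t.start()
    t.join()  # error propagation is omitted to keep the sketch short
    return result["value"]


async def demo() -> str:
    await asyncio.sleep(0)
    return "done"


print(run_async_in_thread_sketch(demo()))  # usable from sync (and async) call sites
```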