
Commit 7a42b1c

Merge branch 'main' into test-based-eval
2 parents 0010187 + 6b2a527 commit 7a42b1c

File tree

19 files changed (+522, -128 lines)

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'
 
   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
 
   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
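
For reference, the `files:` values above are Python-style regular expressions that pre-commit matches against repository-relative paths. A minimal sketch of what the widened formatter/mypy filter now covers (the example paths are hypothetical):

import re

# Filter copied from the hook config above.
pattern = re.compile(r"^(mellea|test|cli|docs).*\.(py|ipynb)$")

assert pattern.match("mellea/backends/huggingface.py")      # package code
assert pattern.match("test/backends/test_huggingface.py")   # matches after the tests -> test rename
assert pattern.match("docs/examples/tutorial.ipynb")        # notebooks are included
assert not pattern.match("scripts/release.py")              # outside the listed top-level directories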

mellea/backends/adapters/adapter.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@
 
 import abc
 import pathlib
-from typing import Any, TypeVar
+from typing import TypeVar
 
 import granite_common.intrinsics
 import yaml
-from litellm import cast
 
 from mellea.backends import Backend
 from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata

mellea/backends/huggingface.py

Lines changed: 112 additions & 43 deletions
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}
 
+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
@@ -245,12 +249,43 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)
 
+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model will raise a ValueError:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")
 
+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -311,33 +346,33 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         # TODO: Handle caching here. granite_common doesn't tell us what changed,
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.
 
-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )
 
-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )
 
         output = ModelOutputThunk(None)
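
The `_generate_with_adapter_lock` wrapper above exists because adapter activation is global, mutable state on the single shared transformers model: whatever adapter is active when `generate` runs decides which weights are used. A hedged sketch of that underlying state, restricted to the PEFT-integration methods that appear in this diff (the model id and adapter path are placeholders):

from transformers import AutoModelForCausalLM

# Placeholders: any causal LM plus a LoRA adapter saved on disk would do.
model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.3-2b-instruct")
model.load_adapter("path/to/lora_adapter", adapter_name="my_adapter")  # loading also activates it

model.set_adapter("my_adapter")   # route subsequent generate() calls through the adapter
print(model.active_adapters())    # ["my_adapter"]

model.set_adapter([])             # deactivate all adapters before a base-model generate()
# Without _generation_lock, another thread could flip this state between
# set_adapter(...) and the actual generate() call, producing output from the wrong weights.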
@@ -490,7 +525,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)
 
         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
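
A note on the call shape used above: `asyncio.to_thread(func, *args, **kwargs)` runs `func(*args, **kwargs)` in a worker thread, so the blocking lock acquisition inside `_generate_with_adapter_lock` never stalls the event loop, and everything after the first two positional arguments is forwarded untouched to the wrapped generate function. A self-contained sketch of that pattern (the names here are illustrative, not Mellea APIs):

import asyncio
import threading
import time

lock = threading.Lock()

def blocking_generate(prompt: str, max_new_tokens: int = 8) -> str:
    """Stand-in for the lock-wrapped HF generate call."""
    with lock:              # serialize, like _generate_with_adapter_lock
        time.sleep(0.1)     # pretend to run model.generate(...)
        return f"echo({max_new_tokens}): {prompt}"

async def main() -> None:
    # Positional and keyword arguments after the callable are forwarded to it.
    results = await asyncio.gather(
        asyncio.to_thread(blocking_generate, "hello"),
        asyncio.to_thread(blocking_generate, "world", max_new_tokens=4),
    )
    print(results)

asyncio.run(main())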
@@ -664,42 +702,41 @@ async def generate_from_raw(
             self._device
         )
 
-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )
 
             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList
 
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )
 
+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
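
For readers unfamiliar with the `format_kwargs` branch: the regex logits processor is what constrains decoding to the requested JSON schema. A hedged, standalone sketch of the same construction (the tokenizer id and the Pydantic model are placeholders; the library calls mirror the ones used in the diff):

import json

import outlines_core.fsm.json_schema
from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from pydantic import BaseModel
from transformers import AutoTokenizer, LogitsProcessorList

class Verdict(BaseModel):
    label: str
    confidence: float

# Build a regex that only accepts JSON conforming to the schema.
schema_json = json.dumps(Verdict.model_json_schema())
regex_str = outlines_core.fsm.json_schema.build_regex_from_schema(schema_json)

# Placeholder tokenizer; the backend uses its own self._tokenizer.
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-2b-instruct")
logits_processor = LogitsProcessorList(
    [RegexLogitsProcessor(regex_str, tokenizer=TransformerTokenizer(tokenizer))]
)
# Passing logits_processor=logits_processor to model.generate(...) masks, at each
# decoding step, every token that would break the JSON structure.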
@@ -853,7 +890,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter
 
     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -880,7 +917,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter
 
     def unload_adapter(self, adapter_qualified_name: str):
@@ -906,6 +943,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())
 
 
+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the current state of the lock
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""

mellea/backends/litellm.py

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@ def __init__(
         base_url: str | None = "http://localhost:11434",
         model_options: dict | None = None,
     ):
-        """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
+        """Initialize an OpenAI compatible backend using the [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk).
 
         Note: If getting `Unclosed client session`, set `export DISABLE_AIOHTTP_TRANSPORT=True` in your environment. See: https://github.com/BerriAI/litellm/issues/13251.
 
         Args:
-            model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
+            model_id : The LiteLLM model identifier; in most cases requires some combination of `<provider>/<model_creator>/<model_name>`. Make sure that all necessary credentials are in OS environment variables.
             formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter
             base_url : Base url for LLM API. Defaults to None.
             model_options : Generation options to pass to the LLM. Defaults to None.

mellea/backends/openai.py

Lines changed: 6 additions & 0 deletions
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }
 
+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         self.load_adapter(adapter.qualified_name)
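
The same translation also appears in the HuggingFace backend above; on its own it is just a keyed copy from Mellea option names into the request body. A minimal sketch (only `ModelOption.TEMPERATURE` is confirmed by this diff; the helper name is illustrative):

from typing import Any

from mellea.backends.types import ModelOption

def apply_model_options(request_json: dict[str, Any], model_options: dict | None) -> None:
    """Copy recognized Mellea options into an OpenAI-style request body."""
    if not model_options:
        return
    for option, value in model_options.items():
        if option == ModelOption.TEMPERATURE:
            request_json["temperature"] = value

# Usage: apply_model_options(request_json, {ModelOption.TEMPERATURE: 0.0})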

mellea/stdlib/intrinsics/rag.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},
         # No rejection sampling, please
         strategy=None,
     )
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )

mellea/templates/prompts/default/LLMaJRequirement.jinja2

Lines changed: 0 additions & 15 deletions
This file was deleted.
