1 change: 0 additions & 1 deletion mellea/backends/__init__.py
@@ -41,7 +41,6 @@ def generate_from_context(
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> ModelOutputThunk: # i.e., ContextDiff
"""Generates a model output from a context. May not mutate the context.
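After this change, `generate_logs` is gone from the abstract `generate_from_context` signature. A minimal sketch of a conforming override is below; only the keyword-only parameters come from the hunk, while the positional parameters (`action`, `ctx`) and the class name are illustrative assumptions:

```python
# Hypothetical sketch of the post-change override; positional parameter
# names and the subclass name are assumed, not taken from this diff.
from mellea.stdlib.base import ModelOutputThunk

class MyBackend:  # would subclass mellea's abstract backend class
    def generate_from_context(
        self,
        action,
        ctx,
        *,
        format=None,
        model_options: dict | None = None,
        tool_calls: bool = False,
    ) -> ModelOutputThunk:
        raise NotImplementedError
```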
6 changes: 3 additions & 3 deletions mellea/backends/aloras/__init__.py
@@ -2,7 +2,7 @@

import abc

from mellea.stdlib.base import CBlock
from mellea.stdlib.base import CBlock, ModelOutputThunk


class Alora(abc.ABC):
@@ -24,8 +24,8 @@ def __init__(self, name: str):
self.name: str = name

@abc.abstractmethod
def generate_using_strings(self, *args, **kwargs) -> str:
"""Generates from the ALora using raw strings as the interface for both inputs and outputs.
def generate_using_strings(self, *args, **kwargs) -> ModelOutputThunk:
"""Generates from the ALora using raw strings as the interface for inputs. In most cases, must be run from a running event loop.

This has a generic signature because each aLoRA has different parameters depending on its functionality and how it gets called.
"""
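With the new return type, callers receive a `ModelOutputThunk` instead of a plain string, and the call must happen inside a running event loop. A hedged usage sketch (the `alora` instance is hypothetical; its keyword arguments mirror the `HFConstraintAlora` signature in the next file, and nothing here is a confirmed public API):

```python
import asyncio

async def check_constraint(alora):
    # Must be called from a running event loop: the thunk's value is
    # produced by a background task created inside the call.
    return alora.generate_using_strings(
        input="What is 2+2?",
        response="4",
        constraint="The response should contain only a number.",
    )

# thunk = asyncio.run(check_constraint(my_alora))  # `my_alora` is hypothetical
```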
220 changes: 116 additions & 104 deletions mellea/backends/aloras/huggingface/granite_aloras.py
@@ -1,12 +1,17 @@
"""Huggingface implementations for IBM's "starter pack" of Activated LoRAs."""

import asyncio
import functools
from copy import deepcopy

import torch
from transformers.generation.utils import GenerateDecoderOnlyOutput

from mellea.backends.huggingface import HFAlora, HFAloraCacheInfo, LocalHFBackend
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import GenerateType, ModelOutputThunk


class HFConstraintAlora(HFAlora):
@@ -44,31 +49,99 @@ def __init__(
self._logger = FancyLogger.get_logger()

def generate_using_strings(
self, input: str, response: str, constraint: str, force_yn: bool = True
) -> str:
"""Generates a constraint response from the ALora."""
self,
input: str,
response: str,
constraint: str,
force_yn: bool = True,
stream: bool = False,
) -> ModelOutputThunk:
"""Generates a constraint response from the ALora. Must be run in a running event loop."""
assert self._backend.alora_model is not None
# Go ahead and do runtime type-checking because passing CBlocks into this function is a common error.
assert type(input) is str
assert type(response) is str
assert type(constraint) is str
self._backend.alora_model.set_adapter(self.name)
cache_hit = self._backend.cache_get(response)

if stream:
self._logger.warning(
"`HFConstraintAlora` cannot stream output; defaulting to non-streaming approach."
)

generate_kwargs = {}
if cache_hit:
self._logger.debug(
f"using cache for alora {self.__class__} and response '{response}'"
)
return self._generate_using_cache(cache_hit, constraint, force_yn)
generate_kwargs["past_key_values"] = deepcopy(cache_hit.kv_cache)
input_combined = self._generate_using_cache(cache_hit, constraint, force_yn)

else:
self._logger.debug(
f"not using cache for alora {self.__class__} and response '{response}'"
)
return self._generate_not_using_cache(input, response, constraint, force_yn)
input_combined = self._generate_not_using_cache(
input, response, constraint, force_yn
)

if not self._include_constraint_in_alora_offset:
alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
else:
# Get the constraint tokens separately so that we can calculate the alora offsets.
constraint_tokens = self._backend._tokenizer(
self._constraint_prompt.format(constraint), return_tensors="pt"
).to(self._backend._device)

alora_offsets = [
constraint_tokens["input_ids"].shape[1]
+ self._generation_prompt_tokens["input_ids"].shape[1]
- 2
]

chat_response = asyncio.to_thread(
self._backend.alora_model.generate,
input_combined["input_ids"].to(self._backend._device),
attention_mask=input_combined["attention_mask"].to(self._backend._device),
max_new_tokens=1,
return_dict_in_generate=True,
alora_offsets=alora_offsets,
output_scores=True,
**generate_kwargs,
)

output = ModelOutputThunk(None)
output._meta["alora_name"] = self.name

output._process = functools.partial(
processing,
backend=self._backend,
force_yn=force_yn,
gen_prompt=self._generation_prompt,
)
output._post_process = functools.partial(post_processing, backend=self._backend)

try:
# To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine.
# We can also support synchronous calls by adding a flag and changing this ._generate function.

# This function should always be called from a running event loop so we don't have to worry about
# scheduling the task to a specific event loop here.
output._generate = asyncio.create_task(
send_to_queue(chat_response, output._async_queue) # type: ignore
)
output._generate_type = GenerateType.ASYNC
except RuntimeError as e:
# Most likely cause is running this function without an event loop present.
raise e

return output
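The shape of the code above — build a coroutine with `asyncio.to_thread` around the blocking `generate` call, then schedule a task that feeds its result into the thunk's queue — can be shown with plain `asyncio`. This is a generic sketch of the pattern, not mellea's `send_to_queue` implementation:

```python
import asyncio

def blocking_generate() -> str:
    # Stands in for the blocking HuggingFace model.generate call.
    return "Y"

async def feed_queue(coro, queue: asyncio.Queue) -> None:
    # Await the threaded work, publish the result, then a sentinel.
    queue.put_nowait(await coro)
    queue.put_nowait(None)

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(
        feed_queue(asyncio.to_thread(blocking_generate), queue)
    )
    while (item := await queue.get()) is not None:
        print(item)  # consume results as they arrive
    await task

asyncio.run(main())
```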

def _generate_using_cache(
self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool
) -> str:
assert self._backend.alora_model is not None
) -> dict:
"""Returns the input object used for generation."""

# Must tokenize the constraint here since the requirement isn't known at initialization.
constraint_tokens = self._backend._tokenizer(
@@ -94,62 +167,16 @@ def _generate_using_cache(
),
}

if not self._include_constraint_in_alora_offset:
alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
else:
alora_offsets = [
constraint_tokens["input_ids"].shape[1]
+ self._generation_prompt_tokens["input_ids"].shape[1]
- 2
]
self._logger.debug(
f"Prompt for cached aLoRA({self.name}):\n {self._backend._tokenizer.decode(input_combined['input_ids'][0])}"
)

if force_yn:
output = self._backend.alora_model.generate(
input_combined["input_ids"].to(self._backend._device),
attention_mask=input_combined["attention_mask"].to(
self._backend._device
),
max_new_tokens=1,
past_key_values=deepcopy(cache_hit.kv_cache),
return_dict_in_generate=True,
alora_offsets=alora_offsets,
output_scores=True,
)
last_logits = output.scores[-1].squeeze(0)
token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[
"input_ids"
][0]
token_N = self._backend._tokenizer("N", add_special_tokens=False)[
"input_ids"
][0]
logit_Y = last_logits[token_Y].item()
logit_N = last_logits[token_N].item()
return "Y" if logit_Y > logit_N else "N"
else:
output = self._backend.alora_model.generate(
input_combined["input_ids"].to(self._backend._device),
attention_mask=input_combined["attention_mask"].to(
self._backend._device
),
max_new_tokens=1,
past_key_values=deepcopy(cache_hit.kv_cache),
return_dict_in_generate=True,
alora_offsets=alora_offsets,
)
output_text = self._backend._tokenizer.decode(output.sequences[0])
assert output_text[-1] in ["Y", "N"], (
f"The constraint model card states the the Requirement Checker model will respond with 'Y' or 'N', but found: {output_text}"
)
constraint_satisfied = output_text.split(self._generation_prompt)[-1]
return constraint_satisfied[0]
return input_combined
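The cached path hands the previously captured KV cache back to `generate` via `past_key_values`; the `deepcopy` in the caller matters because `generate` advances the cache in place. Roughly, in plain `transformers` terms (`model` and `kv_cache` are placeholders, not values defined in this file):

```python
from copy import deepcopy

def generate_with_cached_prefix(model, input_ids, attention_mask, kv_cache):
    # Reuse the prompt prefix's KV cache; copy it so the stored cache
    # is not mutated by this generation call.
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        past_key_values=deepcopy(kv_cache),
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
```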

def _generate_not_using_cache(
self, input: str, response: str, constraint: str, force_yn: bool
) -> str:
assert self._backend.alora_model is not None
) -> dict:
"""Returns the input object used for generation."""

# Params aren't needed when just getting the backend args.
backend_model_opts = self._backend._simplify_and_merge(None)
@@ -185,58 +212,43 @@
),
}

if not self._include_constraint_in_alora_offset:
alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1]
else:
# Get the constraint tokens separately so that we can calculate the alora offsets.
constraint_tokens = self._backend._tokenizer(
self._constraint_prompt.format(constraint), return_tensors="pt"
).to(self._backend._device)

alora_offsets = [
constraint_tokens["input_ids"].shape[1]
+ self._generation_prompt_tokens["input_ids"].shape[1]
- 2
]

self._logger.debug(
f"Prompt for non-cached aLoRA({self.name}):\n{self._backend._tokenizer.decode(input_combined['input_ids'][0])}"
)

if force_yn:
output = self._backend.alora_model.generate(
input_combined["input_ids"].to(self._backend._device),
attention_mask=input_combined["attention_mask"].to(
self._backend._device
),
max_new_tokens=1,
return_dict_in_generate=True,
alora_offsets=alora_offsets,
output_scores=True,
)
last_logits = output.scores[-1].squeeze(0)
token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[
"input_ids"
][0]
token_N = self._backend._tokenizer("N", add_special_tokens=False)[
"input_ids"
][0]
logit_Y = last_logits[token_Y].item()
logit_N = last_logits[token_N].item()
return "Y" if logit_Y > logit_N else "N"
else:
output = self._backend.alora_model.generate(
input_combined["input_ids"].to(self._backend._device),
attention_mask=input_combined["attention_mask"].to(
self._backend._device
),
max_new_tokens=1,
return_dict_in_generate=True,
alora_offsets=alora_offsets,
)
output_text = self._backend._tokenizer.decode(output.sequences[0])
constraint_satisfied = output_text.split(self._generation_prompt)[-1]
return constraint_satisfied[0]
return input_combined
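Both builders leave the offset arithmetic to `generate_using_strings`, which computes `alora_offsets` one of two ways depending on `_include_constraint_in_alora_offset`. A worked sketch with invented token counts (the `-1`/`-2` adjustments are copied from the diff as-is):

```python
# Invented token counts, standing in for the tensor .shape[1] values above.
generation_prompt_len = 7   # self._generation_prompt_tokens["input_ids"].shape[1]
constraint_len = 12         # constraint_tokens["input_ids"].shape[1]

include_constraint_in_alora_offset = True
if not include_constraint_in_alora_offset:
    alora_offsets = [generation_prompt_len - 1]                   # -> [6]
else:
    alora_offsets = [constraint_len + generation_prompt_len - 2]  # -> [17]
```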


async def processing(
mot: ModelOutputThunk,
chunk: GenerateDecoderOnlyOutput,
backend: LocalHFBackend,
force_yn: bool,
gen_prompt: str,
):
if mot._underlying_value is None:
mot._underlying_value = ""

# Streaming isn't supported for HFConstraintAlora, so `chunk` is the complete output and can be processed here in one step.
assert isinstance(chunk, GenerateDecoderOnlyOutput)

if force_yn:
last_logits = chunk.scores[-1].squeeze(0) # type: ignore
token_Y = backend._tokenizer("Y", add_special_tokens=False)["input_ids"][0] # type: ignore
token_N = backend._tokenizer("N", add_special_tokens=False)["input_ids"][0] # type: ignore
logit_Y = last_logits[token_Y].item()
logit_N = last_logits[token_N].item()
mot._underlying_value = "Y" if logit_Y > logit_N else "N"
else:
output_text = backend._tokenizer.decode(chunk.sequences[0])
constraint_satisfied = output_text.split(gen_prompt)[-1]
mot._underlying_value = constraint_satisfied[
0
] # Grab the first char of the str.
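The `force_yn` branch reduces the check to a comparison of two logits at the final generation step. A standalone sketch of the same trick against a generic `transformers` causal LM (`model` and `tokenizer` are placeholders):

```python
def force_yes_no(model, tokenizer, input_ids, attention_mask) -> str:
    out = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,  # keep per-step logits so we can compare Y vs N
    )
    last_logits = out.scores[-1].squeeze(0)
    token_y = tokenizer("Y", add_special_tokens=False)["input_ids"][0]
    token_n = tokenizer("N", add_special_tokens=False)["input_ids"][0]
    return "Y" if last_logits[token_y].item() > last_logits[token_n].item() else "N"
```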


async def post_processing(mot: ModelOutputThunk, backend: LocalHFBackend):
backend.formatter.parse(mot._action, mot) # type: ignore


def add_granite_aloras(backend: LocalHFBackend):