 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}

+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
@@ -245,12 +249,43 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)

+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
255+ """Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
+        with self._generation_lock:
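+            # Switching adapters mutates the shared model object, so the lock must be held
+            # from the adapter switch through the end of generation.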
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model will raise a ValueError:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
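+            # Sanity checks: confirm the expected adapter state immediately before and after generating.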
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
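+    # Note: the async generate paths below hand this helper to a worker thread, roughly:
+    #   asyncio.to_thread(self._generate_with_adapter_lock, <adapter name or "">, <generate fn>, ...)
+    # so the lock is acquired off the event loop.
+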
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")

+        if len(model_options) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten or ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -311,33 +346,33 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
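+        # Note: only temperature is forwarded here for now; other model options are not converted.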
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         # TODO: Handle caching here. granite_common doesn't tell us what changed,
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
@@ -490,7 +525,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)

         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -664,42 +702,41 @@ async def generate_from_raw(
             self._device
         )

-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
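+        # Optional kwargs for constrained decoding; populated only when a response format is requested.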
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting JSON into a python dict.
+            # We, however, want to keep it as a JSON string so it can later be stored in the ModelOutputThunk.
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )

             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList

-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )

+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
@@ -853,7 +890,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter

     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
893+ """Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening. """
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -880,7 +917,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
@@ -906,6 +943,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())


+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the qualified name of the adapter expected to be active, or "" if none should be active
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
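+        # `active_adapters()` returns the names of the adapters currently enabled on the model.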
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""
