Use stopping criteria from transformers (and other minor transformer fixes) #1723
base: main
Changes from all commits
@@ -533,6 +533,7 @@ def to_dict(self) -> dict:
             "timeout",
             "api_base",
             "torch_dtype",
+            "dtype",
             "device_map",
             "organization",
             "project",
@@ -783,8 +784,8 @@ class TransformersModel(Model):
             For example, `"Qwen/Qwen2.5-Coder-32B-Instruct"`.
         device_map (`str`, *optional*):
             The device_map to initialize your model with.
-        torch_dtype (`str`, *optional*):
-            The torch_dtype to initialize your model with.
+        dtype (`str`, *optional*):
+            The dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
         model_kwargs (`dict[str, Any]`, *optional*):
@@ -817,7 +818,7 @@ def __init__(
         self,
         model_id: str | None = None,
         device_map: str | None = None,
-        torch_dtype: str | None = None,
+        dtype: str | None = None,
         trust_remote_code: bool = False,
         model_kwargs: dict[str, Any] | None = None,
         max_new_tokens: int = 4096,
@@ -854,11 +855,16 @@ def __init__(
         logger.info(f"Using device: {device_map}")
         self._is_vlm = False
         self.model_kwargs = model_kwargs or {}
+
+        # BC: previously the type was set through `torch_dtype`. `dtype` is now preferred
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        dtype = dtype or torch_dtype
+
         try:
             self.model = AutoModelForImageTextToText.from_pretrained(
                 model_id,
                 device_map=device_map,
-                torch_dtype=torch_dtype,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 **self.model_kwargs,
             )

Review comments on this hunk:

- OK, I see the explanation here.
- Maybe we should emit a deprecation warning for smolagents users? What do you think?
- Additionally, this makes some CI tests fail: TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
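A minimal sketch of the deprecation warning suggested in the review, assuming the same `kwargs`/`dtype` handling as the hunk above; the helper name is made up for illustration and is not part of this PR:

```python
import warnings
from typing import Any


def _resolve_dtype(dtype: str | None, kwargs: dict[str, Any]) -> str | None:
    """Prefer `dtype`, but accept a legacy `torch_dtype` kwarg with a warning."""
    torch_dtype = kwargs.pop("torch_dtype", None)
    if torch_dtype is not None:
        warnings.warn(
            "`torch_dtype` is deprecated, use `dtype` instead; support will be removed in a future release.",
            FutureWarning,
        )
    return dtype or torch_dtype


# A caller still using the old name gets a FutureWarning and the same behavior:
print(_resolve_dtype(None, {"torch_dtype": "bfloat16"}))  # -> "bfloat16"
```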
@@ -871,7 +877,7 @@ def __init__(
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 device_map=device_map,
-                torch_dtype=torch_dtype,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 **self.model_kwargs,
             )
@@ -886,6 +892,8 @@ def __init__(
         )

     def make_stopping_criteria(self, stop_sequences: list[str], tokenizer) -> "StoppingCriteriaList":
+        warnings.warn("`make_stopping_criteria` is deprecated, pass `stop_strings` directly to `generate` instead")
+
         from transformers import StoppingCriteria, StoppingCriteriaList

         class StopOnStrings(StoppingCriteria):
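The replacement the warning points to is transformers' built-in stop-string support: pass `stop_strings` (plus the `tokenizer`) straight to `generate`. A standalone sketch, where the model ID and stop string are illustrative assumptions:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"  # assumption: any causal LM works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Write a short haiku about the sea.", return_tensors="pt")
out = model.generate(
    **inputs,                     # input_ids and attention_mask
    max_new_tokens=64,
    stop_strings=["<end_code>"],  # generation halts once any stop string is produced
    tokenizer=tokenizer,          # required so generate can decode and match the stop strings
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```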
@@ -939,18 +947,14 @@ def _prepare_completion_args(
             return_dict=True,
         )
         prompt_tensor = prompt_tensor.to(self.model.device)  # type: ignore
-        if hasattr(prompt_tensor, "input_ids"):
-            prompt_tensor = prompt_tensor["input_ids"]

         model_tokenizer = self.processor.tokenizer if hasattr(self, "processor") else self.tokenizer
-        stopping_criteria = (
-            self.make_stopping_criteria(stop_sequences, tokenizer=model_tokenizer) if stop_sequences else None
-        )
         completion_kwargs["max_new_tokens"] = max_new_tokens
         return dict(
-            inputs=prompt_tensor,
+            **prompt_tensor,
             use_cache=True,
-            stopping_criteria=stopping_criteria,
+            stop_strings=stop_sequences,
+            tokenizer=model_tokenizer,
             **completion_kwargs,
         )

Review comments on this hunk:

- Not sure of understanding this change... 😅
- The inputs were being prepared such that only `input_ids` was passed to `generate`. These changes make it so we pass all tokenizer encoding outputs (`input_ids`, `attention_mask`, ...) to `generate`.
- Thanks a lot for the clear explanation! 🤗
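To illustrate the explanation above: the tokenizer returns a `BatchEncoding` that holds both `input_ids` and `attention_mask`, so unpacking it forwards everything to `generate`. A small sketch (the model ID is an assumption):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
enc = tokenizer("Hello there", return_tensors="pt")

print(sorted(enc.keys()))         # ['attention_mask', 'input_ids']
print(enc["input_ids"].shape[1])  # prompt token count, which the hunks below read via input_ids.shape[1]

# Before: only the ids were forwarded    -> model.generate(inputs=enc["input_ids"], ...)
# After:  the whole encoding is unpacked -> model.generate(**enc, ...)
```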
@@ -970,7 +974,7 @@ def generate(
             tools_to_call_from=tools_to_call_from,
             **kwargs,
         )
-        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore
+        count_prompt_tokens = generation_kwargs["input_ids"].shape[1]  # type: ignore
         out = self.model.generate(
             **generation_kwargs,
         )
@@ -987,7 +991,11 @@ def generate(
             content=output_text,
             raw={
                 "out": output_text,
-                "completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"},
+                "completion_kwargs": {
+                    key: value
+                    for key, value in generation_kwargs.items()
+                    if key not in ("input_ids", "attention_mask")
+                },
             },
             token_usage=TokenUsage(
                 input_tokens=count_prompt_tokens,
@@ -1014,7 +1022,7 @@ def generate_stream(
         )

         # Get prompt token count once
-        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore
+        count_prompt_tokens = generation_kwargs["input_ids"].shape[1]  # type: ignore

         # Start generation in a separate thread
         thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs})
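For reference, a standalone sketch of the streaming pattern in the hunk above: `generate` runs on a worker thread while a `TextIteratorStreamer` yields decoded chunks. The model ID and prompt are assumptions:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

enc = tokenizer("Tell me a one-line joke.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a thread and pushes tokens into the streamer
thread = Thread(target=model.generate, kwargs={"streamer": streamer, "max_new_tokens": 64, **enc})
thread.start()
for chunk in streamer:  # yields text pieces as they are generated
    print(chunk, end="", flush=True)
thread.join()
```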
Review comments:

- Is this a breaking change you introduced in transformers?
- yes -- we'll guarantee BC until v5.0.0 I believe
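One possible way to address the CI failure quoted earlier (older transformers releases reject `dtype` in `from_pretrained`) while the rename is still rolling out; the version cutoff below is an assumption, not something stated in this thread:

```python
from packaging import version

import transformers
from transformers import AutoModelForCausalLM

dtype = "bfloat16"
# Assumed cutoff for the torch_dtype -> dtype rename; check the transformers release notes.
if version.parse(transformers.__version__) >= version.parse("4.56.0"):
    dtype_kwargs = {"dtype": dtype}
else:
    dtype_kwargs = {"torch_dtype": dtype}

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", **dtype_kwargs)
```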