diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index bbffd7b66..f08e048e9 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -533,6 +533,7 @@ def to_dict(self) -> dict:
             "timeout",
             "api_base",
             "torch_dtype",
+            "dtype",
             "device_map",
             "organization",
             "project",
@@ -783,8 +784,8 @@ class TransformersModel(Model):
             For example, `"Qwen/Qwen2.5-Coder-32B-Instruct"`.
         device_map (`str`, *optional*):
             The device_map to initialize your model with.
-        torch_dtype (`str`, *optional*):
-            The torch_dtype to initialize your model with.
+        dtype (`str`, *optional*):
+            The dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
         model_kwargs (`dict[str, Any]`, *optional*):
@@ -817,7 +818,7 @@ def __init__(
         self,
         model_id: str | None = None,
         device_map: str | None = None,
-        torch_dtype: str | None = None,
+        dtype: str | None = None,
         trust_remote_code: bool = False,
         model_kwargs: dict[str, Any] | None = None,
         max_new_tokens: int = 4096,
@@ -854,11 +855,16 @@
         logger.info(f"Using device: {device_map}")
         self._is_vlm = False
         self.model_kwargs = model_kwargs or {}
+
+        # BC: previously the dtype was set through `torch_dtype`. `dtype` is now preferred
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        dtype = dtype or torch_dtype
+
         try:
             self.model = AutoModelForImageTextToText.from_pretrained(
                 model_id,
                 device_map=device_map,
-                torch_dtype=torch_dtype,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 **self.model_kwargs,
             )
@@ -871,7 +877,7 @@
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 device_map=device_map,
-                torch_dtype=torch_dtype,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 **self.model_kwargs,
             )
@@ -886,6 +892,8 @@
         )

     def make_stopping_criteria(self, stop_sequences: list[str], tokenizer) -> "StoppingCriteriaList":
+        warnings.warn("`make_stopping_criteria` is deprecated, pass `stop_strings` directly to `generate` instead")
+
         from transformers import StoppingCriteria, StoppingCriteriaList

         class StopOnStrings(StoppingCriteria):
@@ -939,18 +947,14 @@ def _prepare_completion_args(
             return_dict=True,
         )
         prompt_tensor = prompt_tensor.to(self.model.device)  # type: ignore
-        if hasattr(prompt_tensor, "input_ids"):
-            prompt_tensor = prompt_tensor["input_ids"]

         model_tokenizer = self.processor.tokenizer if hasattr(self, "processor") else self.tokenizer
-        stopping_criteria = (
-            self.make_stopping_criteria(stop_sequences, tokenizer=model_tokenizer) if stop_sequences else None
-        )

         completion_kwargs["max_new_tokens"] = max_new_tokens
         return dict(
-            inputs=prompt_tensor,
+            **prompt_tensor,
             use_cache=True,
-            stopping_criteria=stopping_criteria,
+            stop_strings=stop_sequences,
+            tokenizer=model_tokenizer,
             **completion_kwargs,
         )
@@ -970,7 +974,7 @@ def generate(
             tools_to_call_from=tools_to_call_from,
             **kwargs,
         )
-        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore
+        count_prompt_tokens = generation_kwargs["input_ids"].shape[1]  # type: ignore
         out = self.model.generate(
             **generation_kwargs,
         )
@@ -987,7 +991,11 @@
             content=output_text,
             raw={
                 "out": output_text,
-                "completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"},
+                "completion_kwargs": {
+                    key: value
+                    for key, value in generation_kwargs.items()
+                    if key not in ("input_ids", "attention_mask")
+                },
             },
             token_usage=TokenUsage(
                 input_tokens=count_prompt_tokens,
@@ -1014,7 +1022,7 @@ def generate_stream(
         )

         # Get prompt token count once
-        count_prompt_tokens = generation_kwargs["inputs"].shape[1]  # type: ignore
+        count_prompt_tokens = generation_kwargs["input_ids"].shape[1]  # type: ignore

         # Start generation in a separate thread
         thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs})
diff --git a/tests/test_models.py b/tests/test_models.py
index 859d36e16..c63fc5451 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -545,7 +545,7 @@ def test_init(self, patching):
         assert model.model == mocks["transformers.AutoModelForCausalLM.from_pretrained"].return_value
         assert mocks["transformers.AutoModelForCausalLM.from_pretrained"].call_args.kwargs == {
             "device_map": "cpu",
-            "torch_dtype": "float16",
+            "dtype": "float16",
             "trust_remote_code": True,
         }
         assert model.tokenizer == mocks["transformers.AutoTokenizer.from_pretrained"].return_value
@@ -555,7 +555,7 @@
         assert model.model == mocks["transformers.AutoModelForImageTextToText.from_pretrained"].return_value
         assert mocks["transformers.AutoModelForImageTextToText.from_pretrained"].call_args.kwargs == {
             "device_map": "cpu",
-            "torch_dtype": "float16",
+            "dtype": "float16",
             "trust_remote_code": True,
         }
         assert model.processor == mocks["transformers.AutoProcessor.from_pretrained"].return_value
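
Reviewer note, not part of the patch: a minimal usage sketch of the rename, assuming the public `TransformersModel` import path. The legacy `torch_dtype` spelling keeps working because the BC shim above pops it from `**kwargs` and falls back to it when `dtype` is unset:

    from smolagents import TransformersModel

    # New spelling: forwarded to from_pretrained(..., dtype=...)
    model = TransformersModel(
        model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        device_map="auto",
        dtype="bfloat16",
    )

    # Legacy spelling: still accepted, mapped onto `dtype` by the BC shim
    legacy_model = TransformersModel(
        model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        torch_dtype="bfloat16",
    )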
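Reviewer note, not part of the patch: the removed `StopOnStrings` criteria are superseded by `generate()`'s built-in `stop_strings` support in recent transformers releases, which requires passing a tokenizer so candidate stop strings can be matched against decoded tokens. A sketch against a plain model/tokenizer pair, with hypothetical stop markers:

    out = model.generate(
        **inputs,
        max_new_tokens=64,
        stop_strings=["Observation:", "<end_code>"],  # hypothetical markers
        tokenizer=tokenizer,  # required by generate() whenever stop_strings is set
    )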
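Reviewer note, not part of the patch: why the `inputs` key disappears from `generation_kwargs` — with `return_dict=True`, `apply_chat_template` returns a `BatchEncoding` (a mapping carrying `input_ids` and `attention_mask`), so `**prompt_tensor` forwards both tensors to `generate()` and the prompt token count is read from `input_ids` directly. A sketch:

    enc = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    count_prompt_tokens = enc["input_ids"].shape[1]  # same lookup as the patch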