38 changes: 23 additions & 15 deletions src/smolagents/models.py
@@ -533,6 +533,7 @@ def to_dict(self) -> dict:
"timeout",
"api_base",
"torch_dtype",
"dtype",
"device_map",
"organization",
"project",
@@ -783,8 +784,8 @@ class TransformersModel(Model):
For example, `"Qwen/Qwen2.5-Coder-32B-Instruct"`.
device_map (`str`, *optional*):
The device_map to initialize your model with.
torch_dtype (`str`, *optional*):
The torch_dtype to initialize your model with.
dtype (`str`, *optional*):
The dtype to initialize your model with.
Member:
Is this a breaking change you introduced in transformers?

Member Author:
yes -- we'll guarantee BC until v5.0.0 I believe

trust_remote_code (bool, default `False`):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
model_kwargs (`dict[str, Any]`, *optional*):
@@ -817,7 +818,7 @@ def __init__(
self,
model_id: str | None = None,
device_map: str | None = None,
torch_dtype: str | None = None,
dtype: str | None = None,
trust_remote_code: bool = False,
model_kwargs: dict[str, Any] | None = None,
max_new_tokens: int = 4096,
@@ -854,11 +855,16 @@ def __init__(
logger.info(f"Using device: {device_map}")
self._is_vlm = False
self.model_kwargs = model_kwargs or {}

# BC: previously the dtype was set through `torch_dtype`; `dtype` is now preferred
torch_dtype = kwargs.pop("torch_dtype", None)
dtype = dtype or torch_dtype
Member:
OK, I see the explanation here.

Member (@albertvillanova, Aug 27, 2025):
Maybe we should emit a deprecation warning for smolagents users? What do you think?

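As a rough sketch of what such a deprecation warning could look like on the smolagents side (the helper name and message below are illustrative, not code from this PR):

    import warnings

    def _resolve_dtype(dtype: str | None, torch_dtype: str | None) -> str | None:
        # Illustrative only: prefer the new `dtype`, accept the legacy `torch_dtype` with a warning.
        if torch_dtype is not None:
            warnings.warn(
                "`torch_dtype` is deprecated in favor of `dtype`; please pass `dtype` instead.",
                FutureWarning,
            )
        return dtype if dtype is not None else torch_dtype
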
Member:
Additionally, this makes some CI tests fail:

TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
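
One way to stay compatible across transformers versions would be a fallback to the old kwarg when `dtype` is rejected; a minimal sketch of the idea (not part of this PR), assuming older releases raise TypeError for `dtype` but still accept `torch_dtype`:

    from transformers import AutoModelForCausalLM

    def _load_causal_lm_compat(model_id: str, dtype=None, **kwargs):
        # Sketch: try the new `dtype` kwarg first, fall back to the legacy name
        # for transformers versions that do not know `dtype` yet.
        try:
            return AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, **kwargs)
        except TypeError:
            return AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **kwargs)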


try:
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map=device_map,
torch_dtype=torch_dtype,
dtype=dtype,
trust_remote_code=trust_remote_code,
**self.model_kwargs,
)
@@ -871,7 +877,7 @@ def __init__(
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
torch_dtype=torch_dtype,
dtype=dtype,
trust_remote_code=trust_remote_code,
**self.model_kwargs,
)
@@ -886,6 +892,8 @@ def __init__(
)

def make_stopping_criteria(self, stop_sequences: list[str], tokenizer) -> "StoppingCriteriaList":
warnings.warn("`make_stopping_criteria` is deprecated, pass `stop_strings` directly to `generate` instead")

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnStrings(StoppingCriteria):
@@ -939,18 +947,14 @@ def _prepare_completion_args(
return_dict=True,
)
prompt_tensor = prompt_tensor.to(self.model.device) # type: ignore
if hasattr(prompt_tensor, "input_ids"):
prompt_tensor = prompt_tensor["input_ids"]

model_tokenizer = self.processor.tokenizer if hasattr(self, "processor") else self.tokenizer
stopping_criteria = (
self.make_stopping_criteria(stop_sequences, tokenizer=model_tokenizer) if stop_sequences else None
)
completion_kwargs["max_new_tokens"] = max_new_tokens
return dict(
inputs=prompt_tensor,
**prompt_tensor,
Member:
Not sure I understand this change... 😅

Member Author:
The inputs were being prepared such that prompt_tensor only contained the input_ids. However, depending on the models and usage, the corresponding attention_mask (also returned by the tokenizer) may also be needed for a correct output. While using smolagents with transformers models, we could see a related warning being thrown :)

These changes make it so we pass all tokenizer encoding outputs (input_ids AND attention_mask) to model.generate, and thus guarantee correctness.
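
A minimal illustration of the difference, using the model id from the docstring above purely as a placeholder: the tokenizer returns both `input_ids` and `attention_mask`, and unpacking the whole encoding into `generate` forwards both, whereas passing only the ids drops the mask.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"  # placeholder, any causal LM works
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    encoded = tokenizer("Write a haiku about agents", return_tensors="pt")
    # `encoded` typically holds both 'input_ids' and 'attention_mask'

    # Before: only the token ids were forwarded, so generate() had to infer the mask.
    out_ids_only = model.generate(inputs=encoded["input_ids"], max_new_tokens=32)

    # After: unpacking the full encoding forwards input_ids AND attention_mask.
    out_full = model.generate(**encoded, max_new_tokens=32)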

Member:
Thanks a lot for the clear explanation! 🤗

use_cache=True,
stopping_criteria=stopping_criteria,
stop_strings=stop_sequences,
tokenizer=model_tokenizer,
**completion_kwargs,
)

@@ -970,7 +974,7 @@ def generate(
tools_to_call_from=tools_to_call_from,
**kwargs,
)
count_prompt_tokens = generation_kwargs["inputs"].shape[1] # type: ignore
count_prompt_tokens = generation_kwargs["input_ids"].shape[1] # type: ignore
out = self.model.generate(
**generation_kwargs,
)
@@ -987,7 +991,11 @@ def generate(
content=output_text,
raw={
"out": output_text,
"completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"},
"completion_kwargs": {
key: value
for key, value in generation_kwargs.items()
if key not in ("input_ids", "attention_mask")
},
},
token_usage=TokenUsage(
input_tokens=count_prompt_tokens,
@@ -1014,7 +1022,7 @@ def generate_stream(
)

# Get prompt token count once
count_prompt_tokens = generation_kwargs["inputs"].shape[1] # type: ignore
count_prompt_tokens = generation_kwargs["input_ids"].shape[1] # type: ignore

# Start generation in a separate thread
thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs})