diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index 360437bf..badeba75 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -447,7 +447,9 @@ def _generate_from_raw(
         prompts = [self.formatter.print(action) for action in actions]
 
         # batch-encoding call is deprecated in favor of this
-        inputs = self._tokenizer(prompts, return_tensors="pt").to(self._device)
+        inputs = self._tokenizer(prompts, return_tensors="pt", padding=True).to(
+            self._device
+        )
 
         if format is None:
             outputs = self._model.generate(  # type: ignore
diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py
index 6859099a..b7b7d15f 100644
--- a/test/backends/test_huggingface.py
+++ b/test/backends/test_huggingface.py
@@ -179,7 +179,7 @@ class Email(pydantic.BaseModel):
 
 @pytest.mark.qualitative
 def test_generate_from_raw(session):
-    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?", "what is 4+2+2?"]
     results = session.backend._generate_from_raw(
         actions=[CBlock(value=prompt) for prompt in prompts],
         generate_logs=None
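
Context for the fix, as a minimal sketch outside this patch: when a batch of prompts tokenizes to different lengths, the Hugging Face tokenizer cannot stack them into one rectangular tensor, so return_tensors="pt" fails without padding. The model name ("gpt2") and printed shapes below are illustrative, not taken from this repo; gpt2 is used only because it is small and ships without a pad token, which must then be set explicitly.

    # Minimal sketch, assuming a stock transformers install.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 has no pad token by default

    prompts = ["what is 1+1?", "what is 4+2+2?"]  # unequal token counts

    # Without padding, this raises a ValueError: ragged sequences
    # cannot be stacked into one rectangular tensor.
    #   tok(prompts, return_tensors="pt")

    # With padding=True, shorter prompts are padded to the longest in
    # the batch, and attention_mask marks which positions are real.
    inputs = tok(prompts, return_tensors="pt", padding=True)
    print(inputs["input_ids"].shape)   # torch.Size([2, N]) for some N
    print(inputs["attention_mask"])    # zeros at the padded positions

The added test prompt ("what is 4+2+2?") tokenizes to a different length than the others, so the test batch exercises the padding path that the old single-length batch never hit.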