Patch summary: SentenceTransformer embedding __call__ now encodes the input
texts in length-sorted order and returns the embeddings in the original
input order.

Each text is paired with its input index and the pairs are sorted by len()
of the text (per the in-code comment, to minimize the number of padding
tokens). After encode(), every result is written back to the slot of its
original index, and the list is cast to the non-optional element type for
the type checker. The typing import gains `cast`; the `import dataclasses`
line is removed.

diff --git a/python/cocoindex/functions/sbert.py b/python/cocoindex/functions/sbert.py
index 93f9fc44..b4d8c3b0 100644
--- a/python/cocoindex/functions/sbert.py
+++ b/python/cocoindex/functions/sbert.py
@@ -1,7 +1,6 @@
 """SentenceTransformer embedding functionality."""
 
-import dataclasses
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -60,7 +59,18 @@ def analyze(self) -> type:
 
     def __call__(self, text: list[str]) -> list[NDArray[np.float32]]:
         assert self._model is not None
+
+        # Sort the text by length to minimize the number of padding tokens.
+        text_with_idx = [(idx, t) for idx, t in enumerate(text)]
+        text_with_idx.sort(key=lambda x: len(x[1]))
+
         results: list[NDArray[np.float32]] = self._model.encode(
-            text, convert_to_numpy=True
+            [t for _, t in text_with_idx], convert_to_numpy=True
         )
-        return results
+        final_results: list[NDArray[np.float32] | None] = [
+            None for _ in range(len(text))
+        ]
+        for (idx, _), result in zip(text_with_idx, results):
+            final_results[idx] = result
+
+        return cast(list[NDArray[np.float32]], final_results)