Skip to content
Merged
20 changes: 11 additions & 9 deletions mteb/models/model_implementations/siglip_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from tqdm.auto import tqdm

from mteb._requires_package import requires_package
from mteb.models.abs_encoder import AbsEncoder
from mteb.models.model_meta import ModelMeta, ScoringFunction

Expand Down Expand Up @@ -34,12 +35,10 @@ def __init__(
):
from transformers import AutoModel, AutoProcessor

try:
import sentencepiece # noqa: F401
except ImportError:
raise ImportError(
"The `sentencepiece` package is required to run `pip install sentencepiece`"
)
# sentencepiece and protobuf are both installed by the `mteb[siglip]` extra; only sentencepiece is named in the check below:
requires_package(
self, "sentencepiece", model_name, "pip install 'mteb[siglip]'"
)

self.model_name = model_name
self.device = device
Expand Down Expand Up @@ -68,7 +67,8 @@ def get_text_embeddings(
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_outputs = self.model.get_text_features(**inputs)
all_text_embeddings.append(text_outputs.cpu())
embeddings = text_outputs.pooler_output
all_text_embeddings.append(embeddings.cpu())

all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
return all_text_embeddings
Expand All @@ -83,12 +83,14 @@ def get_image_embeddings(

with torch.no_grad():
for batch in tqdm(images):
_images = [img.convert("RGB") for img in batch["image"]]
inputs = self.processor(
images=batch["image"], return_tensors="pt", padding=True
images=_images, return_tensors="pt", padding=True
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
image_outputs = self.model.get_image_features(**inputs)
all_image_embeddings.append(image_outputs.cpu())
embeddings = image_outputs.pooler_output
all_image_embeddings.append(embeddings.cpu())
all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
return all_image_embeddings

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ wav2clip = ["wav2clip==0.1.0"]
torch-vggish-yamnet = ["torch-vggish-yamnet==0.2.1"]
vllm = ["vllm>=0.11.1"]
mctct = ["transformers<5"] # mctct was removed in transformers 5
siglip = ["sentencepiece>=0.2.0","protobuf>=3.0.0"]
qwen-vl = ["transformers>=4.57.0", "qwen-vl-utils>=0.0.14"]

[dependency-groups]
Expand Down
Loading
Loading