Skip to content
This repository was archived by the owner on Oct 15, 2025. It is now read-only.

Commit e45d509

Browse files
authored
fix: update sentence transformer udf (#854)
πŸ‘‹ Thanks for submitting a Pull Request to EvaDB! πŸ™Œ We want to make contributing to EvaDB as easy and transparent as possible. Here are a few tips to get you started: - πŸ” Search existing EvaDB [PRs](https://github.com/georgia-tech-db/eva/pulls) to see if a similar PR already exists. - πŸ”— Link this PR to a EvaDB [issue](https://github.com/georgia-tech-db/eva/issues) to help us understand what bug fix or feature is being implemented. - πŸ“ˆ Provide before and after profiling results to help us quantify the improvement your PR provides (if applicable). πŸ‘‰ Please see our βœ… [Contributing Guide](https://evadb.readthedocs.io/en/stable/source/contribute/index.html) for more details.
1 parent 95c501d commit e45d509

File tree

1 file changed

+8
-34
lines changed

1 file changed

+8
-34
lines changed

β€Ževadb/udfs/sentence_feature_extractor.pyβ€Ž

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
# limitations under the License.
1515
import numpy as np
1616
import pandas as pd
17-
import torch
18-
import torch.nn.functional as F
19-
from transformers import AutoModel, AutoTokenizer
17+
from sentence_transformers import SentenceTransformer
2018

2119
from evadb.catalog.catalog_type import NdArrayType
2220
from evadb.udfs.abstract.abstract_udf import AbstractUDF
@@ -25,30 +23,25 @@
2523
from evadb.udfs.gpu_compatible import GPUCompatible
2624

2725

28-
class SentenceFeatureExtractor(AbstractUDF, GPUCompatible):
26+
class SentenceTransformerFeatureExtractor(AbstractUDF, GPUCompatible):
2927
@setup(cacheable=False, udf_type="FeatureExtraction", batchable=False)
def setup(self):
    """Load the sentence embedding model used by this feature extractor.

    The SentenceTransformer wrapper performs tokenization, pooling, and
    normalization internally, so no tokenizer or manual pooling state is
    kept on the instance — only the model itself.
    """
    # Hard-coded checkpoint name; presumably downloaded/cached on first
    # use by the sentence-transformers library — TODO confirm.
    checkpoint = "all-MiniLM-L6-v2"
    self.model = SentenceTransformer(checkpoint)
3630

3731
def to_device(self, device: str) -> GPUCompatible:
    """Relocate the wrapped model to *device* and return self (fluent).

    Args:
        device: target device string (e.g. a CUDA device identifier),
            forwarded directly to the model's ``.to(...)``.

    Returns:
        This extractor instance, for call chaining.
    """
    relocated = self.model.to(device)
    self.model = relocated
    return self
4134

4235
@property
def name(self) -> str:
    """Public identifier of this UDF."""
    # Must match the class name so lookups stay consistent.
    return "SentenceTransformerFeatureExtractor"
4538

4639
@forward(
4740
input_signatures=[
4841
PandasDataframe(
4942
columns=["data"],
5043
column_types=[NdArrayType.STR],
51-
column_shapes=[(None, 1)],
44+
column_shapes=[(1)],
5245
)
5346
],
5447
output_signatures=[
@@ -61,28 +54,9 @@ def name(self) -> str:
6154
)
6255
def forward(self, df: pd.DataFrame) -> pd.DataFrame:
6356
def _forward(row: pd.Series) -> np.ndarray:
64-
sentence = row[0]
65-
66-
encoded_input = self.tokenizer(
67-
[sentence], padding=True, truncation=True, return_tensors="pt"
68-
)
69-
if self.model_device is not None:
70-
encoded_input.to(self.model_device)
71-
with torch.no_grad():
72-
model_output = self.model(**encoded_input)
73-
74-
attention_mask = encoded_input["attention_mask"]
75-
token_embedding = model_output[0]
76-
input_mask_expanded = (
77-
attention_mask.unsqueeze(-1).expand(token_embedding.size()).float()
78-
)
79-
sentence_embedding = torch.sum(
80-
token_embedding * input_mask_expanded, 1
81-
) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
82-
sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
83-
84-
sentence_embedding_np = sentence_embedding.cpu().numpy()
85-
return sentence_embedding_np
57+
data = row
58+
embedded_list = self.model.encode(data)
59+
return embedded_list
8660

8761
ret = pd.DataFrame()
8862
ret["features"] = df.apply(_forward, axis=1)

0 commit comments

Comments (0)