Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions examples/benchmarking/using_beir_with_a_custom_sbert_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
"""Using BEIR with a Custom SBERT Model

Automatically generated by Colaboratory.

Original file is located at
https://colab.research.google.com/drive/1552z6RVGVaLgIlLy4AOXXGaweQgiEA4r

## Downloading Sentence Transformers
"""

!pip install sentence_transformers

"""Downloading the Semantic Textua Similarity dataset"""

!wget http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz

"""Taking a look at the files"""

# Commented out IPython magic to ensure Python compatibility.
# %%bash
#
# tar -xvf /content/Stsbenchmark.tar.gz

!cat stsbenchmark/readme.txt

"""Looking into the file contents"""

!head -n 10 stsbenchmark/sts-dev.csv

"""What is the average length of the train/dev/test files?"""

# Commented out IPython magic to ensure Python compatibility.
# %%bash
# dev_len=$(wc -l <stsbenchmark/sts-dev.csv)
# echo $dev_len
# train_len=$(wc -l <stsbenchmark/sts-train.csv)
# echo $train_len
# sum=$(($train_len+$dev_len))
# echo $((sum/2))

"""## Model building"""

from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import sys
import os
import gzip
import csv
import torch

# Training configuration.
sts_dataset_path = '/content/Stsbenchmark.tar.gz'  # NOTE(review): unused below — data is read from the extracted stsbenchmark/ folder instead
train_batch_size = 64
num_epochs = 20

model_name = 'distilbert-base-uncased' #Good starting point for a Transformer model

# Timestamped output directory so repeated runs do not overwrite each other.
model_save_path = 'output/training_stsbenchmark_'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

print(model_save_path)

#Bi-Encoder with Bottleneck Layer
# Architecture: Transformer -> mean pooling -> Dense 768 -> Dense 384 (bottleneck) -> Dense 768.
# NOTE(review): this bottleneck model is built but never trained — `model` is
# rebound to the simple model below, so this section only costs a weight download.

word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),
out_features=768,
activation_function=torch.nn.GELU())
dense_model_inter = models.Dense(in_features=dense_model.get_sentence_embedding_dimension(),
out_features=384,
activation_function=torch.nn.GELU())
dense_output = models.Dense(in_features=dense_model_inter.get_sentence_embedding_dimension(),
out_features=768,
activation_function=torch.nn.GELU())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model,dense_model,dense_model_inter,dense_output])

# Simple model
# Architecture: Transformer -> mean pooling -> single Dense 768 head.
# This rebinding of `model` is the one actually trained by model.fit below.

word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),
out_features=768,
activation_function=torch.nn.GELU())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model,dense_model])

# Properly loading the data
def create_data_samples(path: str, sample_list: list = None, encoding: str = "utf-8") -> list:
    """Parse an STS-benchmark TSV file into a list of InputExample objects.

    Each tab-separated row is expected to carry the gold similarity score in
    column 4 and the sentence pair in columns 5 and 6 (as in the STS-B CSVs).

    Args:
        path: Path to the tab-separated STS file.
        sample_list: Optional list to append examples to. A fresh list is
            created when omitted, so callers no longer need to pass ``[]``
            (and no mutable default is shared between calls).
        encoding: Text encoding used to open the file.

    Returns:
        The list containing one InputExample per row, with the label scaled
        from the 0-5 gold range down to [0, 1].
    """
    if sample_list is None:
        sample_list = []
    with open(path, encoding=encoding) as f:
        # QUOTE_NONE: STS sentences may contain stray quote characters that
        # must not be interpreted as CSV quoting.
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE, dialect="excel")
        for row in reader:
            score = float(row[4]) / 5.0  # scale gold score (0-5) to [0, 1]
            sentence1 = row[5]
            sentence2 = row[6]
            sample_list.append(InputExample(texts=[sentence1, sentence2], label=score))
    return sample_list

# Load the three STS-B splits; each call returns a list of InputExample.
train_samples=create_data_samples(path='/content/stsbenchmark/sts-train.csv',sample_list=[],encoding="utf-8")
test_samples=create_data_samples(path='/content/stsbenchmark/sts-test.csv',sample_list=[],encoding="utf-8")
dev_samples=create_data_samples(path='/content/stsbenchmark/sts-dev.csv',sample_list=[],encoding="utf-8")

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

"""Define the Loss function"""

# Cosine-similarity regression loss: pushes embedding cosine toward the gold label.
train_loss = losses.CosineSimilarityLoss(model=model)

"""Define the evaluator"""

# Dev-set evaluator used during training (Spearman/Pearson on cosine scores).
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Warm up the learning rate over the first 10% of all training steps.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)

"""Time to train the model"""

# Trains in place and writes the best checkpoint (by dev score) to model_save_path.
model.fit(train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=num_epochs,
evaluation_steps=1000,
warmup_steps=warmup_steps,
output_path=model_save_path)

"""Testing the model performance on test set"""

# Reload the best saved checkpoint (rather than the last in-memory state)
# before scoring on the held-out test split.
model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
test_evaluator(model, output_path=model_save_path)

"""# Loading the pre-trained model

Lets us test the model on a few examples
"""

# Sanity-check the trained model on a few hand-picked sentence pairs.
sentences1 = ['The cat sits outside',
'A man is playing guitar',
'The new movie is awesome']

sentences2 = ['The dog plays in the garden',
'A woman watches TV',
'The new movie is so great']

# convert_to_tensor=True keeps embeddings on-device as torch tensors for cos_sim.
embeddings1 = model.encode(sentences1, convert_to_tensor=True)
embeddings2 = model.encode(sentences2, convert_to_tensor=True)

# Pairwise cosine matrix; the diagonal holds the score for each aligned pair.
cosine_scores = util.cos_sim(embeddings1, embeddings2)
for i in range(len(sentences1)):
    print("{} \t\t {} \t\t Score: {:.4f}".format(sentences1[i], sentences2[i], cosine_scores[i][i]))

!zip -r '/content/output.zip' '/content/output/'

"""# Benchmarking Zero Shot Performance using BEIR

"""

!pip install beir

!pip install tensorflow-text

"""Let us download Scifact dataset for testing Zero Shot performance of our model"""

# NOTE(review): `from beir import util` rebinds the `util` name that previously
# pointed at sentence_transformers.util (used above for cos_sim). From here on,
# `util` means beir.util.
from beir import util, LoggingHandler

import logging
import pathlib, os

# NOTE(review): duplicate of the imports just above — harmless but redundant.
import pathlib, os
from beir import util

# Download and extract the SciFact BEIR dataset into ./datasets/.
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))

"""Peeking into scifact folder files"""

!ls datasets/scifact/

from beir.datasets.data_loader import GenericDataLoader

# corpus: doc_id -> {title, text}; queries: query_id -> text;
# qrels: query_id -> {doc_id: relevance}.
data_path = "datasets/scifact"
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test") # or split = "train" or "dev"

"""# Define the CustomModel"""

from typing import List, Dict
import numpy as np

class DistilBertModel:
    """Adapter exposing a SentenceTransformer checkpoint through the
    ``encode_queries`` / ``encode_corpus`` interface that BEIR's dense
    retrieval (DenseRetrievalExactSearch) expects.
    """

    def __init__(self, model_path=None, **kwargs):
        # model_path: saved checkpoint directory or hub model name accepted
        # by SentenceTransformer.
        self.model = SentenceTransformer(model_path)

    @staticmethod
    def _as_numpy(embeddings) -> np.ndarray:
        """Coerce ``encode()`` output to a numpy array.

        ``SentenceTransformer.encode`` returns a torch tensor when the caller
        passes ``convert_to_tensor=True`` and a numpy array otherwise. The
        original code called ``.cpu()`` unconditionally, which raises
        AttributeError on the ndarray path; handle both here.
        """
        if torch.is_tensor(embeddings):
            embeddings = embeddings.detach().cpu().numpy()
        return np.asarray(embeddings)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
        """Return query embeddings as a numpy array of shape (len(queries), dim)."""
        return self._as_numpy(self.model.encode(queries, batch_size=batch_size, **kwargs))

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
        """Return document embeddings; each document is embedded as
        '<title> <text>' (stripped), matching BEIR's corpus format."""
        sentences = [(doc["title"] + " " + doc["text"]).strip() for doc in corpus]
        return self._as_numpy(self.model.encode(sentences, batch_size=batch_size, **kwargs))

"""Finally evaluate the results on Scifact dataset using our custom-model"""

from beir.retrieval.evaluation import EvaluateRetrieval
#from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

model = DRES(DistilBertModel(model_path=model_save_path), batch_size=64)
retriever = EvaluateRetrieval(model, score_function="cos_sim")

#### Retrieve dense results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)

print(ndcg, _map, recall, precision)