Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion beir/retrieval/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class EvaluateRetrieval:

def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS]] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
def __init__(self, retriever: Union[DRES, DRFS, BM25, SS] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
self.k_values = k_values
self.top_k = max(k_values)
self.retriever = retriever
Expand Down
12 changes: 6 additions & 6 deletions beir/retrieval/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

class TrainRetriever:

def __init__(self, model: Type[SentenceTransformer], batch_size: int = 64):
def __init__(self, model: SentenceTransformer, batch_size: int = 64):
self.model = model
self.batch_size = batch_size

def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
qrels: Dict[str, Dict[str, int]]) -> List[Type[InputExample]]:
qrels: Dict[str, Dict[str, int]]) -> List[InputExample]:

query_ids = list(queries.keys())
train_samples = []
Expand All @@ -40,7 +40,7 @@ def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
logger.info("Loaded {} training pairs.".format(len(train_samples)))
return train_samples

def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type[InputExample]]:
def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[InputExample]:

train_samples = []

Expand All @@ -53,15 +53,15 @@ def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type
logger.info("Loaded {} training pairs.".format(len(train_samples)))
return train_samples

def prepare_train(self, train_dataset: List[Type[InputExample]], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
def prepare_train(self, train_dataset: List[InputExample], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:

if not dataset_present:
train_dataset = SentencesDataset(train_dataset, model=self.model)

train_dataloader = DataLoader(train_dataset, shuffle=shuffle, batch_size=self.batch_size)
return train_dataloader

def prepare_train_triplets(self, train_dataset: List[Type[InputExample]]) -> DataLoader:
def prepare_train_triplets(self, train_dataset: List[InputExample]) -> DataLoader:

train_dataloader = datasets.NoDuplicatesDataLoader(train_dataset, batch_size=self.batch_size)
return train_dataloader
Expand Down Expand Up @@ -117,7 +117,7 @@ def fit(self,
steps_per_epoch = None,
scheduler: str = 'WarmupLinear',
warmup_steps: int = 10000,
optimizer_class: Type[Optimizer] = AdamW,
optimizer_class: Optimizer = AdamW,
optimizer_params : Dict[str, object]= {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
weight_decay: float = 0.01,
evaluation_steps: int = 0,
Expand Down