Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 4439f94

Browse files
authored
[RAG] Handle TFIDF retriever with pre-trained model (#4436)
* handle tfidf retriever from init pre-trained model * needed this as well * move change to TGA
1 parent 6147606 commit 4439f94

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

parlai/agents/rag/retrievers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,12 +728,17 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None
728728
assert self.max_doc_paragraphs != 0
729729
if not shared:
730730
self.tfidf_retriever = create_agent(tfidf_opt)
731+
self.query_encoder = DprQueryEncoder(
732+
opt, dpr_model=opt['query_model'], pretrained_path=opt['dpr_model_file']
733+
)
731734
else:
732735
self.tfidf_retriever = shared['tfidf_retriever']
736+
self.query_encoder = shared['query_encoder']
733737

734738
def share(self) -> TShared:
735739
shared = super().share()
736740
shared['tfidf_retriever'] = self.tfidf_retriever
741+
shared['query_encoder'] = self.query_encoder
737742
return shared
738743

739744
def retrieve_and_score(

parlai/core/torch_generator_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def _add_generation_metrics(self, batch, preds):
820820
"""
821821
self.record_local_metric(
822822
'gen_n_toks',
823-
AverageMetric.many([p.size(0) for p in preds], [1] * batch.batchsize),
823+
AverageMetric.many([p.size(0) for p in preds], [1] * len(preds)),
824824
)
825825

826826
def rank_eval_label_candidates(self, batch, batchsize):

0 commit comments

Comments
 (0)