Skip to content

Commit f9be71b

Browse files
authored
Fix rag (#38585)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 9eac19e commit f9be71b

File tree

1 file changed

+61
-18
lines changed

1 file changed

+61
-18
lines changed

tests/models/rag/test_modeling_rag.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest.mock import patch
2222

2323
import numpy as np
24+
import requests
2425

2526
from transformers import BartTokenizer, T5Tokenizer
2627
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
@@ -49,7 +50,7 @@
4950
if is_torch_available() and is_datasets_available() and is_faiss_available():
5051
import faiss
5152
import torch
52-
from datasets import Dataset
53+
from datasets import Dataset, load_dataset
5354

5455
from transformers import (
5556
AutoConfig,
@@ -679,6 +680,24 @@ def config_and_inputs(self):
679680
@require_tokenizers
680681
@require_torch_non_multi_accelerator
681682
class RagModelIntegrationTests(unittest.TestCase):
683+
@classmethod
684+
def setUpClass(cls):
685+
cls.temp_dir = tempfile.TemporaryDirectory()
686+
cls.dataset_path = cls.temp_dir.name
687+
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
688+
689+
ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
690+
ds.save_to_disk(cls.dataset_path)
691+
692+
url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
693+
response = requests.get(url, stream=True)
694+
with open(cls.index_path, "wb") as fp:
695+
fp.write(response.content)
696+
697+
@classmethod
698+
def tearDownClass(cls):
699+
cls.temp_dir.cleanup()
700+
682701
def tearDown(self):
683702
super().tearDown()
684703
# clean-up as much as possible GPU memory occupied by PyTorch
@@ -722,8 +741,9 @@ def get_rag_config(self):
722741
max_combined_length=300,
723742
dataset="wiki_dpr",
724743
dataset_split="train",
725-
index_name="exact",
726-
index_path=None,
744+
index_name="custom",
745+
passages_path=self.dataset_path,
746+
index_path=self.index_path,
727747
use_dummy_dataset=True,
728748
retrieval_vector_size=768,
729749
retrieval_batch_size=8,
@@ -841,8 +861,8 @@ def test_rag_token_generate_beam(self):
841861
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
842862

843863
# Expected outputs as given by model at integration time.
844-
EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
845-
EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"
864+
EXPECTED_OUTPUT_TEXT_1 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Uses"'
865+
EXPECTED_OUTPUT_TEXT_2 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Ways"'
846866

847867
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
848868
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@@ -903,7 +923,10 @@ def test_data_questions(self):
903923
def test_rag_sequence_generate_batch(self):
904924
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
905925
retriever = RagRetriever.from_pretrained(
906-
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
926+
"facebook/rag-sequence-nq",
927+
index_name="custom",
928+
passages_path=self.dataset_path,
929+
index_path=self.index_path,
907930
)
908931
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
909932
torch_device
@@ -926,12 +949,13 @@ def test_rag_sequence_generate_batch(self):
926949

927950
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
928951

952+
# PR #31938 caused the output to change from `june 22, 2018` to `june 22 , 2018`.
929953
EXPECTED_OUTPUTS = [
930954
" albert einstein",
931-
" june 22, 2018",
955+
" june 22 , 2018",
932956
" amplitude modulation",
933957
" tim besley ( chairman )",
934-
" june 20, 2018",
958+
" june 20 , 2018",
935959
" 1980",
936960
" 7.0",
937961
" 8",
@@ -943,9 +967,9 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
943967
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
944968
retriever = RagRetriever.from_pretrained(
945969
"facebook/rag-sequence-nq",
946-
index_name="exact",
947-
use_dummy_dataset=True,
948-
dataset_revision="b24a417",
970+
index_name="custom",
971+
passages_path=self.dataset_path,
972+
index_path=self.index_path,
949973
)
950974
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
951975
torch_device
@@ -981,10 +1005,10 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
9811005

9821006
EXPECTED_OUTPUTS = [
9831007
" albert einstein",
984-
" june 22, 2018",
1008+
" june 22 , 2018",
9851009
" amplitude modulation",
9861010
" tim besley ( chairman )",
987-
" june 20, 2018",
1011+
" june 20 , 2018",
9881012
" 1980",
9891013
" 7.0",
9901014
" 8",
@@ -995,7 +1019,7 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
9951019
def test_rag_token_generate_batch(self):
9961020
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
9971021
retriever = RagRetriever.from_pretrained(
998-
"facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
1022+
"facebook/rag-token-nq", index_name="custom", passages_path=self.dataset_path, index_path=self.index_path
9991023
)
10001024
rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
10011025
torch_device
@@ -1023,10 +1047,10 @@ def test_rag_token_generate_batch(self):
10231047

10241048
EXPECTED_OUTPUTS = [
10251049
" albert einstein",
1026-
" september 22, 2017",
1050+
" september 22 , 2017",
10271051
" amplitude modulation",
10281052
" stefan persson",
1029-
" april 20, 2018",
1053+
" april 20 , 2018",
10301054
" the 1970s",
10311055
" 7.1. 2",
10321056
" 13",
@@ -1037,6 +1061,24 @@ def test_rag_token_generate_batch(self):
10371061
@require_torch
10381062
@require_retrieval
10391063
class RagModelSaveLoadTests(unittest.TestCase):
1064+
@classmethod
1065+
def setUpClass(cls):
1066+
cls.temp_dir = tempfile.TemporaryDirectory()
1067+
cls.dataset_path = cls.temp_dir.name
1068+
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
1069+
1070+
ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
1071+
ds.save_to_disk(cls.dataset_path)
1072+
1073+
url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
1074+
response = requests.get(url, stream=True)
1075+
with open(cls.index_path, "wb") as fp:
1076+
fp.write(response.content)
1077+
1078+
@classmethod
1079+
def tearDownClass(cls):
1080+
cls.temp_dir.cleanup()
1081+
10401082
def tearDown(self):
10411083
super().tearDown()
10421084
# clean-up as much as possible GPU memory occupied by PyTorch
@@ -1060,8 +1102,9 @@ def get_rag_config(self):
10601102
max_combined_length=300,
10611103
dataset="wiki_dpr",
10621104
dataset_split="train",
1063-
index_name="exact",
1064-
index_path=None,
1105+
index_name="custom",
1106+
passages_path=self.dataset_path,
1107+
index_path=self.index_path,
10651108
use_dummy_dataset=True,
10661109
retrieval_vector_size=768,
10671110
retrieval_batch_size=8,

0 commit comments

Comments
 (0)