2121from unittest .mock import patch
2222
2323import numpy as np
24+ import requests
2425
2526from transformers import BartTokenizer , T5Tokenizer
2627from transformers .models .bert .tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
4950if is_torch_available () and is_datasets_available () and is_faiss_available ():
5051 import faiss
5152 import torch
52- from datasets import Dataset
53+ from datasets import Dataset , load_dataset
5354
5455 from transformers import (
5556 AutoConfig ,
@@ -679,6 +680,24 @@ def config_and_inputs(self):
679680@require_tokenizers
680681@require_torch_non_multi_accelerator
681682class RagModelIntegrationTests (unittest .TestCase ):
683+ @classmethod
684+ def setUpClass (cls ):
685+ cls .temp_dir = tempfile .TemporaryDirectory ()
686+ cls .dataset_path = cls .temp_dir .name
687+ cls .index_path = os .path .join (cls .temp_dir .name , "index.faiss" )
688+
689+ ds = load_dataset ("hf-internal-testing/wiki_dpr_dummy" )["train" ]
690+ ds .save_to_disk (cls .dataset_path )
691+
692+ url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
693+ response = requests .get (url , stream = True )
694+ with open (cls .index_path , "wb" ) as fp :
695+ fp .write (response .content )
696+
697+ @classmethod
698+ def tearDownClass (cls ):
699+ cls .temp_dir .cleanup ()
700+
682701 def tearDown (self ):
683702 super ().tearDown ()
684703 # clean-up as much as possible GPU memory occupied by PyTorch
@@ -722,8 +741,9 @@ def get_rag_config(self):
722741 max_combined_length = 300 ,
723742 dataset = "wiki_dpr" ,
724743 dataset_split = "train" ,
725- index_name = "exact" ,
726- index_path = None ,
744+ index_name = "custom" ,
745+ passages_path = self .dataset_path ,
746+ index_path = self .index_path ,
727747 use_dummy_dataset = True ,
728748 retrieval_vector_size = 768 ,
729749 retrieval_batch_size = 8 ,
@@ -841,8 +861,8 @@ def test_rag_token_generate_beam(self):
841861 output_text_2 = rag_decoder_tokenizer .decode (output_ids [1 ], skip_special_tokens = True )
842862
843863 # Expected outputs as given by model at integration time.
844- EXPECTED_OUTPUT_TEXT_1 = " \" She's My Kind of Girl"
845- EXPECTED_OUTPUT_TEXT_2 = " \" She's My Kind of Love"
864+ EXPECTED_OUTPUT_TEXT_1 = '" She\ ' s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Uses"'
865+ EXPECTED_OUTPUT_TEXT_2 = '" She\ ' s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and " Love Has Its Ways"'
846866
847867 self .assertEqual (output_text_1 , EXPECTED_OUTPUT_TEXT_1 )
848868 self .assertEqual (output_text_2 , EXPECTED_OUTPUT_TEXT_2 )
@@ -903,7 +923,10 @@ def test_data_questions(self):
903923 def test_rag_sequence_generate_batch (self ):
904924 tokenizer = RagTokenizer .from_pretrained ("facebook/rag-sequence-nq" )
905925 retriever = RagRetriever .from_pretrained (
906- "facebook/rag-sequence-nq" , index_name = "exact" , use_dummy_dataset = True , dataset_revision = "b24a417"
926+ "facebook/rag-sequence-nq" ,
927+ index_name = "custom" ,
928+ passages_path = self .dataset_path ,
929+ index_path = self .index_path ,
907930 )
908931 rag_sequence = RagSequenceForGeneration .from_pretrained ("facebook/rag-sequence-nq" , retriever = retriever ).to (
909932 torch_device
@@ -926,12 +949,13 @@ def test_rag_sequence_generate_batch(self):
926949
927950 outputs = tokenizer .batch_decode (output_ids , skip_special_tokens = True )
928951
952+ # PR #31938 cause the output being changed from `june 22, 2018` to `june 22 , 2018`.
929953 EXPECTED_OUTPUTS = [
930954 " albert einstein" ,
931- " june 22, 2018" ,
955+ " june 22 , 2018" ,
932956 " amplitude modulation" ,
933957 " tim besley ( chairman )" ,
934- " june 20, 2018" ,
958+ " june 20 , 2018" ,
935959 " 1980" ,
936960 " 7.0" ,
937961 " 8" ,
@@ -943,9 +967,9 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
943967 tokenizer = RagTokenizer .from_pretrained ("facebook/rag-sequence-nq" )
944968 retriever = RagRetriever .from_pretrained (
945969 "facebook/rag-sequence-nq" ,
946- index_name = "exact " ,
947- use_dummy_dataset = True ,
948- dataset_revision = "b24a417" ,
970+ index_name = "custom " ,
971+ passages_path = self . dataset_path ,
972+ index_path = self . index_path ,
949973 )
950974 rag_sequence = RagSequenceForGeneration .from_pretrained ("facebook/rag-sequence-nq" , retriever = retriever ).to (
951975 torch_device
@@ -981,10 +1005,10 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
9811005
9821006 EXPECTED_OUTPUTS = [
9831007 " albert einstein" ,
984- " june 22, 2018" ,
1008+ " june 22 , 2018" ,
9851009 " amplitude modulation" ,
9861010 " tim besley ( chairman )" ,
987- " june 20, 2018" ,
1011+ " june 20 , 2018" ,
9881012 " 1980" ,
9891013 " 7.0" ,
9901014 " 8" ,
@@ -995,7 +1019,7 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
9951019 def test_rag_token_generate_batch (self ):
9961020 tokenizer = RagTokenizer .from_pretrained ("facebook/rag-token-nq" )
9971021 retriever = RagRetriever .from_pretrained (
998- "facebook/rag-token-nq" , index_name = "exact " , use_dummy_dataset = True , dataset_revision = "b24a417"
1022+ "facebook/rag-token-nq" , index_name = "custom " , passages_path = self . dataset_path , index_path = self . index_path
9991023 )
10001024 rag_token = RagTokenForGeneration .from_pretrained ("facebook/rag-token-nq" , retriever = retriever ).to (
10011025 torch_device
@@ -1023,10 +1047,10 @@ def test_rag_token_generate_batch(self):
10231047
10241048 EXPECTED_OUTPUTS = [
10251049 " albert einstein" ,
1026- " september 22, 2017" ,
1050+ " september 22 , 2017" ,
10271051 " amplitude modulation" ,
10281052 " stefan persson" ,
1029- " april 20, 2018" ,
1053+ " april 20 , 2018" ,
10301054 " the 1970s" ,
10311055 " 7.1. 2" ,
10321056 " 13" ,
@@ -1037,6 +1061,24 @@ def test_rag_token_generate_batch(self):
10371061@require_torch
10381062@require_retrieval
10391063class RagModelSaveLoadTests (unittest .TestCase ):
1064+ @classmethod
1065+ def setUpClass (cls ):
1066+ cls .temp_dir = tempfile .TemporaryDirectory ()
1067+ cls .dataset_path = cls .temp_dir .name
1068+ cls .index_path = os .path .join (cls .temp_dir .name , "index.faiss" )
1069+
1070+ ds = load_dataset ("hf-internal-testing/wiki_dpr_dummy" )["train" ]
1071+ ds .save_to_disk (cls .dataset_path )
1072+
1073+ url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
1074+ response = requests .get (url , stream = True )
1075+ with open (cls .index_path , "wb" ) as fp :
1076+ fp .write (response .content )
1077+
1078+ @classmethod
1079+ def tearDownClass (cls ):
1080+ cls .temp_dir .cleanup ()
1081+
10401082 def tearDown (self ):
10411083 super ().tearDown ()
10421084 # clean-up as much as possible GPU memory occupied by PyTorch
@@ -1060,8 +1102,9 @@ def get_rag_config(self):
10601102 max_combined_length = 300 ,
10611103 dataset = "wiki_dpr" ,
10621104 dataset_split = "train" ,
1063- index_name = "exact" ,
1064- index_path = None ,
1105+ index_name = "custom" ,
1106+ passages_path = self .dataset_path ,
1107+ index_path = self .index_path ,
10651108 use_dummy_dataset = True ,
10661109 retrieval_vector_size = 768 ,
10671110 retrieval_batch_size = 8 ,
0 commit comments