 logger = datasets.logging.get_logger(__name__)
 
 
+_HOMEPAGE = "https://github.com/abisee/cnn-dailymail"
+
 _DESCRIPTION = """\
 CNN/DailyMail non-anonymized summarization dataset.
 
@@ -63,13 +65,11 @@
 """
 
 _DL_URLS = {
-    # pylint: disable=line-too-long
     "cnn_stories": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ",
     "dm_stories": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs",
-    "test_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
-    "train_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
-    "val_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
-    # pylint: enable=line-too-long
+    "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
+    "validation": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
+    "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
 }
 
 _HIGHLIGHTS = "highlights"
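
The url-list entries are re-keyed by split name so the downloaded paths can be looked up by split further down. A small, hedged sketch of the assumption this relies on, namely that `DownloadManager.download` preserves the structure of the dict it receives (not part of the diff, network access required):

from datasets import DownloadManager

# hedged sketch: download() returns a dict with the same keys as its input,
# each mapped to a locally cached file; nothing is extracted at this point
dl_paths = DownloadManager().download(
    {"validation": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"}
)
print(dl_paths["validation"])  # local path of the cached all_val.txt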
@@ -104,7 +104,7 @@ def __init__(self, **kwargs):
 
 def _get_url_hashes(path):
     """Get hashes of urls in file."""
-    urls = _read_text_file(path)
+    urls = _read_text_file_path(path)
 
     def url_hash(u):
         h = hashlib.sha1()
@@ -115,47 +115,12 @@ def url_hash(u):
         h.update(u)
         return h.hexdigest()
 
-    return {url_hash(u): True for u in urls}
+    return {url_hash(u) for u in urls}
 
 
 def _get_hash_from_path(p):
     """Extract hash from path."""
-    basename = os.path.basename(p)
-    return basename[0 : basename.find(".story")]
-
-
-def _find_files(dl_paths, publisher, url_dict):
-    """Find files corresponding to urls."""
-    if publisher == "cnn":
-        top_dir = os.path.join(dl_paths["cnn_stories"], "cnn", "stories")
-    elif publisher == "dm":
-        top_dir = os.path.join(dl_paths["dm_stories"], "dailymail", "stories")
-    else:
-        logger.fatal("Unsupported publisher: %s", publisher)
-    files = sorted(os.listdir(top_dir))
-
-    ret_files = []
-    for p in files:
-        if _get_hash_from_path(p) in url_dict:
-            ret_files.append(os.path.join(top_dir, p))
-    return ret_files
-
-
-def _subset_filenames(dl_paths, split):
-    """Get filenames for a particular split."""
-    assert isinstance(dl_paths, dict), dl_paths
-    # Get filenames for a split.
-    if split == datasets.Split.TRAIN:
-        urls = _get_url_hashes(dl_paths["train_urls"])
-    elif split == datasets.Split.VALIDATION:
-        urls = _get_url_hashes(dl_paths["val_urls"])
-    elif split == datasets.Split.TEST:
-        urls = _get_url_hashes(dl_paths["test_urls"])
-    else:
-        logger.fatal("Unsupported split: %s", split)
-    cnn = _find_files(dl_paths, "cnn", urls)
-    dm = _find_files(dl_paths, "dm", urls)
-    return cnn + dm
+    return os.path.splitext(os.path.basename(p))[0]
 
 
 DM_SINGLE_CLOSE_QUOTE = "\u2019"  # unicode
@@ -164,14 +129,16 @@ def _subset_filenames(dl_paths, split):
 END_TOKENS = [".", "!", "?", "...", "'", "`", '"', DM_SINGLE_CLOSE_QUOTE, DM_DOUBLE_CLOSE_QUOTE, ")"]
 
 
-def _read_text_file(text_file):
-    lines = []
-    with open(text_file, "r", encoding="utf-8") as f:
-        for line in f:
-            lines.append(line.strip())
+def _read_text_file_path(path):
+    with open(path, "r", encoding="utf-8") as f:
+        lines = [line.strip() for line in f]
     return lines
 
 
+def _read_text_file(file):
+    return [line.decode("utf-8").strip() for line in file]
+
+
 def _get_art_abs(story_file, tfds_version):
     """Get abstract (highlights) and article from a story file path."""
     # Based on https://github.com/abisee/cnn-dailymail/blob/master/
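
There are now two readers because the inputs differ: the url-list files are still plain local paths, while `dl_manager.iter_archive` yields archive members as binary file objects whose lines must be decoded before stripping. A rough, hedged illustration with stand-in data, calling the two helpers defined just above:

import io

# the url lists arrive as local paths -> read in text mode
with open("urls_sample.txt", "w", encoding="utf-8") as f:  # hypothetical sample file
    f.write("http://example.com/a\nhttp://example.com/b\n")
print(_read_text_file_path("urls_sample.txt"))   # ['http://example.com/a', 'http://example.com/b']

# archive members from iter_archive arrive as binary file objects -> decode each line
print(_read_text_file(io.BytesIO(b"first line\nsecond line\n")))  # ['first line', 'second line']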
@@ -231,7 +198,6 @@ class CnnDailymail(datasets.GeneratorBasedBuilder):
     ]
 
     def _info(self):
-        # Should return a datasets.DatasetInfo object
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
             features=datasets.Features(
@@ -242,7 +208,7 @@ def _info(self):
                 }
             ),
             supervised_keys=None,
-            homepage="https://github.com/abisee/cnn-dailymail",
+            homepage=_HOMEPAGE,
             citation=_CITATION,
         )
 
@@ -251,29 +217,33 @@ def _vocab_text_gen(self, paths):
             yield " ".join([ex[_ARTICLE], ex[_HIGHLIGHTS]])
 
     def _split_generators(self, dl_manager):
-        dl_paths = dl_manager.download_and_extract(_DL_URLS)
-        train_files = _subset_filenames(dl_paths, datasets.Split.TRAIN)
-        # Generate shared vocabulary
-
+        dl_paths = dl_manager.download(_DL_URLS)
         return [
-            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": train_files}),
-            datasets.SplitGenerator(
-                name=datasets.Split.VALIDATION,
-                gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.VALIDATION)},
-            ),
             datasets.SplitGenerator(
-                name=datasets.Split.TEST, gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.TEST)}
-            ),
+                name=split,
+                gen_kwargs={
+                    "urls_file": dl_paths[split],
+                    "cnn_stories_archive": dl_manager.iter_archive(dl_paths["cnn_stories"]),
+                    "dm_stories_archive": dl_manager.iter_archive(dl_paths["dm_stories"]),
+                },
+            )
+            for split in [datasets.Split.TRAIN, datasets.Split.VALIDATION, datasets.Split.TEST]
         ]
 
-    def _generate_examples(self, files):
-        for p in files:
-            article, highlights = _get_art_abs(p, self.config.version)
-            if not article or not highlights:
-                continue
-            fname = os.path.basename(p)
-            yield fname, {
-                _ARTICLE: article,
-                _HIGHLIGHTS: highlights,
-                "id": _get_hash_from_path(fname),
-            }
+    def _generate_examples(self, urls_file, cnn_stories_archive, dm_stories_archive):
+        urls = _get_url_hashes(urls_file)
+        idx = 0
+        # each split's urls cover stories from both the CNN and the DailyMail archive
+        for archive in (cnn_stories_archive, dm_stories_archive):
+            for path, file in archive:
+                hash_from_path = _get_hash_from_path(path)
+                if hash_from_path in urls:
+                    article, highlights = _get_art_abs(file, self.config.version)
+                    if not article or not highlights:
+                        continue
+                    yield idx, {
+                        _ARTICLE: article,
+                        _HIGHLIGHTS: highlights,
+                        "id": hash_from_path,
+                    }
+                    idx += 1
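
Taken together, `download` plus `iter_archive` read the stories straight out of the two .tgz archives instead of extracting them first, which is what makes streaming possible. A minimal usage sketch, assuming this script is loaded as the standard `cnn_dailymail` dataset with its `3.0.0` config:

from datasets import load_dataset

# stream examples without downloading and extracting the full archives up front
ds = load_dataset("cnn_dailymail", "3.0.0", split="validation", streaming=True)
example = next(iter(ds))
print(example["id"])
print(example["highlights"])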