
Commit 966d3bc

Support streaming cnn_dailymail dataset (#4188)
* Support streaming cnn_dailymail dataset
* Refactor URLS
* Fix dataset card
1 parent e90a7d4 commit 966d3bc
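
In user-facing terms, the change lets the dataset be consumed lazily: the Google Drive story archives are read on the fly with `dl_manager.iter_archive` instead of being downloaded and extracted up front. A minimal usage sketch, assuming the standard `load_dataset(..., streaming=True)` entry point and the existing "3.0.0" config of this dataset:

    from itertools import islice

    from datasets import load_dataset

    # Streaming mode: examples are yielded as the archives are read,
    # with no full extraction to local disk first.
    ds = load_dataset("cnn_dailymail", "3.0.0", split="validation", streaming=True)
    for example in islice(ds, 3):
        print(example["id"], example["highlights"][:80])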

2 files changed, +41 -73 lines changed


datasets/cnn_dailymail/README.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ source_datasets:
 task_categories:
 - summarization
 task_ids:
-- summarization-news-articles-summarization
+- news-articles-summarization
 paperswithcode_id: cnn-daily-mail-1
 pretty_name: CNN / Daily Mail
 ---

datasets/cnn_dailymail/cnn_dailymail.py

Lines changed: 40 additions & 72 deletions
@@ -25,6 +25,8 @@
 logger = datasets.logging.get_logger(__name__)
 
 
+_HOMEPAGE = "https://github.com/abisee/cnn-dailymail"
+
 _DESCRIPTION = """\
 CNN/DailyMail non-anonymized summarization dataset.
 
@@ -63,13 +65,11 @@
 """
 
 _DL_URLS = {
-    # pylint: disable=line-too-long
     "cnn_stories": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ",
     "dm_stories": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs",
-    "test_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
-    "train_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
-    "val_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
-    # pylint: enable=line-too-long
+    "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
+    "validation": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
+    "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
 }
 
 _HIGHLIGHTS = "highlights"
@@ -104,7 +104,7 @@ def __init__(self, **kwargs):
 
 def _get_url_hashes(path):
     """Get hashes of urls in file."""
-    urls = _read_text_file(path)
+    urls = _read_text_file_path(path)
 
     def url_hash(u):
         h = hashlib.sha1()
@@ -115,47 +115,12 @@ def url_hash(u):
         h.update(u)
         return h.hexdigest()
 
-    return {url_hash(u): True for u in urls}
+    return {url_hash(u) for u in urls}
 
 
 def _get_hash_from_path(p):
     """Extract hash from path."""
-    basename = os.path.basename(p)
-    return basename[0 : basename.find(".story")]
-
-
-def _find_files(dl_paths, publisher, url_dict):
-    """Find files corresponding to urls."""
-    if publisher == "cnn":
-        top_dir = os.path.join(dl_paths["cnn_stories"], "cnn", "stories")
-    elif publisher == "dm":
-        top_dir = os.path.join(dl_paths["dm_stories"], "dailymail", "stories")
-    else:
-        logger.fatal("Unsupported publisher: %s", publisher)
-    files = sorted(os.listdir(top_dir))
-
-    ret_files = []
-    for p in files:
-        if _get_hash_from_path(p) in url_dict:
-            ret_files.append(os.path.join(top_dir, p))
-    return ret_files
-
-
-def _subset_filenames(dl_paths, split):
-    """Get filenames for a particular split."""
-    assert isinstance(dl_paths, dict), dl_paths
-    # Get filenames for a split.
-    if split == datasets.Split.TRAIN:
-        urls = _get_url_hashes(dl_paths["train_urls"])
-    elif split == datasets.Split.VALIDATION:
-        urls = _get_url_hashes(dl_paths["val_urls"])
-    elif split == datasets.Split.TEST:
-        urls = _get_url_hashes(dl_paths["test_urls"])
-    else:
-        logger.fatal("Unsupported split: %s", split)
-    cnn = _find_files(dl_paths, "cnn", urls)
-    dm = _find_files(dl_paths, "dm", urls)
-    return cnn + dm
+    return os.path.splitext(os.path.basename(p))[0]
 
 
 DM_SINGLE_CLOSE_QUOTE = "\u2019"  # unicode
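
The `_get_hash_from_path` rewrite above is behavior-preserving for the `.story` members it is applied to: dropping the extension with `os.path.splitext` gives the same stem as the old slice on `basename.find(".story")`. A quick illustrative check (the path is a made-up example in the archives' naming scheme):

    import os

    p = "cnn/stories/0001d1afc246a7964130f43ae940af6bc6c57f01.story"
    basename = os.path.basename(p)
    old_stem = basename[0 : basename.find(".story")]  # original slicing approach
    new_stem = os.path.splitext(basename)[0]          # simplified approach
    assert old_stem == new_stem == "0001d1afc246a7964130f43ae940af6bc6c57f01"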
@@ -164,14 +129,16 @@ def _subset_filenames(dl_paths, split):
 END_TOKENS = [".", "!", "?", "...", "'", "`", '"', DM_SINGLE_CLOSE_QUOTE, DM_DOUBLE_CLOSE_QUOTE, ")"]
 
 
-def _read_text_file(text_file):
-    lines = []
-    with open(text_file, "r", encoding="utf-8") as f:
-        for line in f:
-            lines.append(line.strip())
+def _read_text_file_path(path):
+    with open(path, "r", encoding="utf-8") as f:
+        lines = [line.strip() for line in f]
     return lines
 
 
+def _read_text_file(file):
+    return [line.decode("utf-8").strip() for line in file]
+
+
 def _get_art_abs(story_file, tfds_version):
     """Get abstract (highlights) and article from a story file path."""
     # Based on https://github.com/abisee/cnn-dailymail/blob/master/
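
The reader is split in two because stories now arrive in two forms: `_read_text_file_path` takes a local path (the downloaded URL-list files), while `_read_text_file` takes the binary file objects that `dl_manager.iter_archive` yields from the story archives, which is why it decodes each line from bytes. A rough local analogue of that second case, assuming the archive behaves like an ordinary tar file (the helper and file name below are illustrative, not part of the library):

    import tarfile

    def iter_archive_like(archive_path):
        # Stand-in for dl_manager.iter_archive: yields
        # (path inside archive, binary file object) pairs.
        with tarfile.open(archive_path) as tar:
            for member in tar.getmembers():
                if member.isfile():
                    yield member.name, tar.extractfile(member)

    for path, file in iter_archive_like("cnn_stories.tgz"):
        lines = [line.decode("utf-8").strip() for line in file]  # same decode as _read_text_file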
@@ -231,7 +198,6 @@ class CnnDailymail(datasets.GeneratorBasedBuilder):
     ]
 
     def _info(self):
-        # Should return a datasets.DatasetInfo object
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
             features=datasets.Features(
@@ -242,7 +208,7 @@ def _info(self):
                 }
             ),
             supervised_keys=None,
-            homepage="https://github.com/abisee/cnn-dailymail",
+            homepage=_HOMEPAGE,
             citation=_CITATION,
         )
 
@@ -251,29 +217,31 @@ def _vocab_text_gen(self, paths):
             yield " ".join([ex[_ARTICLE], ex[_HIGHLIGHTS]])
 
     def _split_generators(self, dl_manager):
-        dl_paths = dl_manager.download_and_extract(_DL_URLS)
-        train_files = _subset_filenames(dl_paths, datasets.Split.TRAIN)
-        # Generate shared vocabulary
-
+        dl_paths = dl_manager.download(_DL_URLS)
         return [
-            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": train_files}),
-            datasets.SplitGenerator(
-                name=datasets.Split.VALIDATION,
-                gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.VALIDATION)},
-            ),
             datasets.SplitGenerator(
-                name=datasets.Split.TEST, gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.TEST)}
-            ),
+                name=split,
+                gen_kwargs={
+                    "urls_file": dl_paths[split],
+                    "cnn_stories_archive": dl_manager.iter_archive(dl_paths["cnn_stories"]),
+                    "dm_stories_archive": dl_manager.iter_archive(dl_paths["dm_stories"]),
+                },
+            )
+            for split in [datasets.Split.TRAIN, datasets.Split.VALIDATION, datasets.Split.TEST]
         ]
 
-    def _generate_examples(self, files):
-        for p in files:
-            article, highlights = _get_art_abs(p, self.config.version)
-            if not article or not highlights:
-                continue
-            fname = os.path.basename(p)
-            yield fname, {
-                _ARTICLE: article,
-                _HIGHLIGHTS: highlights,
-                "id": _get_hash_from_path(fname),
-            }
+    def _generate_examples(self, urls_file, cnn_stories_archive, dm_stories_archive):
+        urls = _get_url_hashes(urls_file)
+        idx = 0
+        for path, file in cnn_stories_archive:
+            hash_from_path = _get_hash_from_path(path)
+            if hash_from_path in urls:
+                article, highlights = _get_art_abs(file, self.config.version)
+                if not article or not highlights:
+                    continue
+                yield idx, {
+                    _ARTICLE: article,
+                    _HIGHLIGHTS: highlights,
+                    "id": hash_from_path,
+                }
+                idx += 1
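
Taken together, the new generator filters archive members by split: `_get_url_hashes` turns a split's URL list into a set of SHA-1 hex digests, and a story is kept only when its filename stem (via `_get_hash_from_path`) appears in that set. Roughly, the membership test looks like this (the URL is a made-up illustration, and the utf-8 encoding step is an assumption about how `h.update(u)` receives bytes):

    import hashlib

    def url_hash(u):
        return hashlib.sha1(u.encode("utf-8")).hexdigest()

    example_url = "https://www.cnn.com/2015/01/01/example-article/index.html"
    urls = {url_hash(u) for u in [example_url]}    # what _get_url_hashes builds per split
    story_name = url_hash(example_url) + ".story"  # archive members are named by URL hash
    stem = story_name.rsplit(".", 1)[0]            # what _get_hash_from_path extracts
    print(stem in urls)                            # True -> story belongs to this split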
