Commit 7eb0c7a

Host hf datasets and upgrade examples (#1910)
* hosting squad dataset
* fix squad no answer metric
* fix DataCollatorForSeq2Seq
* hosting hf datasets
* enhance import warning
1 parent f9fab94 commit 7eb0c7a

File tree

12 files changed: +1320 -6 lines changed

examples/language_model/bert/run_glue.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 import paddle
 from paddle.io import DataLoader
 from paddle.metric import Metric, Accuracy, Precision, Recall
+import paddlenlp

 from datasets import load_dataset
 from paddlenlp.data import default_data_collator, DataCollatorWithPadding
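
The added line matters only for its position: importing paddlenlp before the Hugging Face datasets package lets PaddleNLP redirect dataset downloads to its own hosting (see the paddlenlp/__init__.py change below). A minimal sketch of the intended import order, assuming both packages are installed; the GLUE task name is only an illustration:

import paddlenlp  # must come before `datasets` so PaddleNLP's dataset hosting takes effect

from datasets import load_dataset

# Hypothetical task choice; any Hugging Face dataset id is loaded the same way.
train_ds = load_dataset("glue", "sst2", split="train")
print(train_ds[0])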

examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ def main():
         batched=True,
         remove_columns=column_names,
         num_proc=4)
+
     batchify_fn = lambda samples, fn=Dict(
         {
             "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),

examples/machine_reading_comprehension/SQuAD/run_squad.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@ def prepare_validation_features(examples, tokenizer, args):
         # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
         # position is part of the context or not.
         tokenized_examples["offset_mapping"][i] = [
-            (o if sequence_ids[k] == context_index else None)
+            (o if sequence_ids[k] == context_index and
+             k != len(sequence_ids) - 1 else None)
             for k, o in enumerate(tokenized_examples["offset_mapping"][i])
         ]

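
The new `k != len(sequence_ids) - 1` guard additionally masks the offset of the last token in each feature, so answer post-processing can never map a predicted span onto the trailing special token. A minimal sketch of the masking with hypothetical token metadata (the values are made up and chosen so that the last position would otherwise pass the context check):

# sequence_ids: None = special token, 0 = question, 1 = context.
sequence_ids = [None, 0, 0, None, 1, 1, 1, 1]
offset_mapping = [(0, 0), (0, 3), (4, 9), (0, 0),
                  (0, 5), (6, 11), (12, 17), (0, 0)]
context_index = 1

masked = [
    (o if sequence_ids[k] == context_index and
     k != len(sequence_ids) - 1 else None)
    for k, o in enumerate(offset_mapping)
]
# Only genuine context tokens keep their character offsets.
print(masked)  # [None, None, None, None, (0, 5), (6, 11), (12, 17), None]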

paddlenlp/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -17,8 +17,10 @@
 if 'datasets' in sys.modules.keys():
     from paddlenlp.utils.log import logger
     logger.warning(
-        "datasets module loaded before paddlenlp. "
-        "This may cause PaddleNLP datasets to be unavalible in intranet.")
+        "Detected that the datasets module was imported before paddlenlp. "
+        "This may cause PaddleNLP datasets to be unavailable on an intranet. "
+        "Please import paddlenlp before the datasets module to avoid download issues."
+    )
 from . import data
 from . import datasets
 from . import embeddings
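
The check works because any module imported earlier in the process is already present in sys.modules when paddlenlp/__init__.py runs. A minimal sketch of the order that triggers the warning versus the recommended one, assuming both packages are installed:

# Triggers the warning: `datasets` is already in sys.modules when paddlenlp's
# import-time check runs.
import datasets
import paddlenlp  # logs the "imported before paddlenlp" warning

# Recommended order (in a fresh interpreter): import paddlenlp first so it can
# set up its dataset hosting before `datasets` is loaded.
# import paddlenlp
# import datasets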

paddlenlp/data/collate.py

Lines changed: 1 addition & 1 deletion
@@ -332,4 +332,4 @@ def __call__(self, data):
                 ret.extend(result)
             else:
                 ret.append(result)
-        return tuple(ret)
+        return tuple(ret)

(The removed and added lines are textually identical; the change is whitespace-only, most likely adding the missing newline at the end of the file.)

paddlenlp/data/data_collator.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def __call__(self, data, return_tensors=None):
         if (labels is not None and self.model is not None and
                 hasattr(self.model, "prepare_decoder_input_ids_from_labels")):
             decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
-                labels=batch["labels"])
+                labels=paddle.to_tensor(batch["labels"]))
         if not return_tensors:
             batch["decoder_input_ids"] = decoder_input_ids.numpy()
         if self.return_tensors:
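
At this point in the collator the padded labels are still plain Python lists / NumPy data, while prepare_decoder_input_ids_from_labels operates on paddle Tensors, hence the explicit paddle.to_tensor conversion. A minimal sketch of that contract with a hypothetical stand-in model (not a real PaddleNLP class; real seq2seq models usually also replace the -100 label padding before shifting):

import paddle

class DummySeq2SeqModel:
    """Stand-in exposing the hook that DataCollatorForSeq2Seq looks for."""

    def prepare_decoder_input_ids_from_labels(self, labels):
        # Teacher-forcing shift: move the labels one position to the right and
        # prepend a start token (0 here).
        start = paddle.zeros([labels.shape[0], 1], dtype=labels.dtype)
        return paddle.concat([start, labels[:, :-1]], axis=-1)

model = DummySeq2SeqModel()
labels = [[5, 6, 7, -100], [8, 9, -100, -100]]  # padded label ids
# Mirrors the fixed call site: the batch holds plain lists, so convert to a
# paddle Tensor before calling the model hook.
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(
    labels=paddle.to_tensor(labels))
print(decoder_input_ids.numpy())
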
Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""CNN/DailyMail Summarization dataset, non-anonymized version."""

import hashlib
import os

import datasets

logger = datasets.logging.get_logger(__name__)

_DESCRIPTION = """\
CNN/DailyMail non-anonymized summarization dataset.

There are two features:
  - article: text of news article, used as the document to be summarized
  - highlights: joined text of highlights with <s> and </s> around each
    highlight, which is the target summary
"""

# The second citation introduces the source data, while the first
# introduces the specific form (non-anonymized) we use here.
_CITATION = """\
@article{DBLP:journals/corr/SeeLM17,
  author    = {Abigail See and
               Peter J. Liu and
               Christopher D. Manning},
  title     = {Get To The Point: Summarization with Pointer-Generator Networks},
  journal   = {CoRR},
  volume    = {abs/1704.04368},
  year      = {2017},
  url       = {http://arxiv.org/abs/1704.04368},
  archivePrefix = {arXiv},
  eprint    = {1704.04368},
  timestamp = {Mon, 13 Aug 2018 16:46:08 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/SeeLM17},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

@inproceedings{hermann2015teaching,
  title={Teaching machines to read and comprehend},
  author={Hermann, Karl Moritz and Kocisky, Tomas and Grefenstette, Edward and Espeholt, Lasse and Kay, Will and Suleyman, Mustafa and Blunsom, Phil},
  booktitle={Advances in neural information processing systems},
  pages={1693--1701},
  year={2015}
}
"""

_DL_URLS = {
    # pylint: disable=line-too-long
    "cnn_stories":
    "https://bj.bcebos.com/paddlenlp/datasets/cnn_dailymail/cnn_stories.tgz",
    "dm_stories":
    "https://bj.bcebos.com/paddlenlp/datasets/cnn_dailymail/dailymail_stories.tgz",
    "test_urls":
    "https://bj.bcebos.com/paddlenlp/datasets/cnn_dailymail/all_test.txt",
    "train_urls":
    "https://bj.bcebos.com/paddlenlp/datasets/cnn_dailymail/all_train.txt",
    "val_urls":
    "https://bj.bcebos.com/paddlenlp/datasets/cnn_dailymail/all_val.txt",
    # pylint: enable=line-too-long
}

_HIGHLIGHTS = "highlights"
_ARTICLE = "article"

_SUPPORTED_VERSIONS = [
    # Using cased version.
    datasets.Version("3.0.0", "Using cased version."),
    # Same data as 0.0.2
    datasets.Version("1.0.0", ""),
    # Having the model predict newline separators makes it easier to evaluate
    # using summary-level ROUGE.
    datasets.Version("2.0.0", "Separate target sentences with newline."),
]

_DEFAULT_VERSION = datasets.Version("3.0.0", "Using cased version.")


class CnnDailymailConfig(datasets.BuilderConfig):
    """BuilderConfig for CnnDailymail."""

    def __init__(self, **kwargs):
        """BuilderConfig for CnnDailymail.

        Args:

          **kwargs: keyword arguments forwarded to super.
        """
        super(CnnDailymailConfig, self).__init__(**kwargs)


def _get_url_hashes(path):
    """Get hashes of urls in file."""
    urls = _read_text_file(path)

    def url_hash(u):
        h = hashlib.sha1()
        try:
            u = u.encode("utf-8")
        except UnicodeDecodeError:
            logger.error("Cannot hash url: %s", u)
        h.update(u)
        return h.hexdigest()

    return {url_hash(u): True for u in urls}


def _get_hash_from_path(p):
    """Extract hash from path."""
    basename = os.path.basename(p)
    return basename[0:basename.find(".story")]


def _find_files(dl_paths, publisher, url_dict):
    """Find files corresponding to urls."""
    if publisher == "cnn":
        top_dir = os.path.join(dl_paths["cnn_stories"], "cnn", "stories")
    elif publisher == "dm":
        top_dir = os.path.join(dl_paths["dm_stories"], "dailymail", "stories")
    else:
        logger.fatal("Unsupported publisher: %s", publisher)
    files = sorted(os.listdir(top_dir))

    ret_files = []
    for p in files:
        if _get_hash_from_path(p) in url_dict:
            ret_files.append(os.path.join(top_dir, p))
    return ret_files


def _subset_filenames(dl_paths, split):
    """Get filenames for a particular split."""
    assert isinstance(dl_paths, dict), dl_paths
    # Get filenames for a split.
    if split == datasets.Split.TRAIN:
        urls = _get_url_hashes(dl_paths["train_urls"])
    elif split == datasets.Split.VALIDATION:
        urls = _get_url_hashes(dl_paths["val_urls"])
    elif split == datasets.Split.TEST:
        urls = _get_url_hashes(dl_paths["test_urls"])
    else:
        logger.fatal("Unsupported split: %s", split)
    cnn = _find_files(dl_paths, "cnn", urls)
    dm = _find_files(dl_paths, "dm", urls)
    return cnn + dm


DM_SINGLE_CLOSE_QUOTE = "\u2019"  # unicode
DM_DOUBLE_CLOSE_QUOTE = "\u201d"
# acceptable ways to end a sentence
END_TOKENS = [
    ".", "!", "?", "...", "'", "`", '"', DM_SINGLE_CLOSE_QUOTE,
    DM_DOUBLE_CLOSE_QUOTE, ")"
]


def _read_text_file(text_file):
    lines = []
    with open(text_file, "r", encoding="utf-8") as f:
        for line in f:
            lines.append(line.strip())
    return lines


def _get_art_abs(story_file, tfds_version):
    """Get abstract (highlights) and article from a story file path."""
    # Based on https://github.com/abisee/cnn-dailymail/blob/master/
    # make_datafiles.py

    lines = _read_text_file(story_file)

    # The github code lowercase the text and we removed it in 3.0.0.

    # Put periods on the ends of lines that are missing them
    # (this is a problem in the dataset because many image captions don't end in
    # periods; consequently they end up in the body of the article as run-on
    # sentences)
    def fix_missing_period(line):
        """Adds a period to a line that is missing a period."""
        if "@highlight" in line:
            return line
        if not line:
            return line
        if line[-1] in END_TOKENS:
            return line
        return line + " ."

    lines = [fix_missing_period(line) for line in lines]

    # Separate out article and abstract sentences
    article_lines = []
    highlights = []
    next_is_highlight = False
    for line in lines:
        if not line:
            continue  # empty line
        elif line.startswith("@highlight"):
            next_is_highlight = True
        elif next_is_highlight:
            highlights.append(line)
        else:
            article_lines.append(line)

    # Make article into a single string
    article = " ".join(article_lines)

    if tfds_version >= "2.0.0":
        abstract = "\n".join(highlights)
    else:
        abstract = " ".join(highlights)

    return article, abstract


class CnnDailymail(datasets.GeneratorBasedBuilder):
    """CNN/DailyMail non-anonymized summarization dataset."""

    BUILDER_CONFIGS = [
        CnnDailymailConfig(
            name=str(version), description="Plain text", version=version)
        for version in _SUPPORTED_VERSIONS
    ]

    def _info(self):
        # Should return a datasets.DatasetInfo object
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features({
                _ARTICLE: datasets.Value("string"),
                _HIGHLIGHTS: datasets.Value("string"),
                "id": datasets.Value("string"),
            }),
            supervised_keys=None,
            homepage="https://github.com/abisee/cnn-dailymail",
            citation=_CITATION, )

    def _vocab_text_gen(self, paths):
        for _, ex in self._generate_examples(paths):
            yield " ".join([ex[_ARTICLE], ex[_HIGHLIGHTS]])

    def _split_generators(self, dl_manager):
        dl_paths = dl_manager.download_and_extract(_DL_URLS)
        train_files = _subset_filenames(dl_paths, datasets.Split.TRAIN)
        # Generate shared vocabulary

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"files": train_files}),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "files": _subset_filenames(dl_paths,
                                               datasets.Split.VALIDATION)
                }, ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "files": _subset_filenames(dl_paths, datasets.Split.TEST)
                }),
        ]

    def _generate_examples(self, files):
        for p in files:
            article, highlights = _get_art_abs(p, self.config.version)
            if not article or not highlights:
                continue
            fname = os.path.basename(p)
            yield fname, {
                _ARTICLE: article,
                _HIGHLIGHTS: highlights,
                "id": _get_hash_from_path(fname),
            }
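
With the download URLs in _DL_URLS pointing at bj.bcebos.com, the script is used through the regular datasets API. A usage sketch, assuming this builder is the one resolved for the "cnn_dailymail" name and that paddlenlp is imported first so its hosted mirrors are preferred:

import paddlenlp  # import before `datasets` (see paddlenlp/__init__.py above)
from datasets import load_dataset

# "3.0.0" selects the cased configuration declared in _SUPPORTED_VERSIONS.
dataset = load_dataset("cnn_dailymail", "3.0.0", split="validation")
example = dataset[0]
print(example["article"][:200])
print(example["highlights"])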
