
Commit 58eb2be

add sentence & character level data augmentation api (#4194)
* add_sentence_dataaug
* add_char_augmentation
* add_antonym
* add_file_augmentation
* add_test
1 parent 20adadc commit 58eb2be
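
In short, this commit extends the existing word-level data augmentation API with character- and sentence-level strategies, antonym-based substitution, and a file-level entry point. A hedged usage sketch (the class name `CharSubstitute` and the `"antonym"` strategy string are assumptions inferred from the commit message and the `char_antonym.json` data file added below, not confirmed signatures):

```python
# Assumed usage of the new char-level API; CharSubstitute and "antonym"
# are inferred from the commit message and char_antonym.json below.
from paddlenlp.dataaug import CharSubstitute

aug = CharSubstitute("antonym", create_n=2)
# After this commit, a single string input returns a nested list:
# one inner list holding the create_n augmented variants.
print(aug.augment("今天天气很好"))
```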

15 files changed: +3067 −974 lines

docs/dataaug.md

Lines changed: 907 additions & 273 deletions
(Large diff not rendered.)

paddlenlp/dataaug/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .word_substitute import *
-from .word_insert import *
-from .word_delete import *
-from .word_swap import *
+from .base_augment import FileAugment
+from .char import *
+from .sentence import *
+from .word import *
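
The package now exposes char- and sentence-level modules alongside a consolidated `word` module (replacing `word_substitute`, `word_insert`, `word_delete`, and `word_swap`), plus the new `FileAugment` helper. A minimal import sketch, assuming the previous word-level class names carry over into the consolidated module:

```python
# FileAugment is added by this commit; WordSubstitute is assumed to be
# re-exported from the new consolidated .word module.
from paddlenlp.dataaug import FileAugment, WordSubstitute
```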

paddlenlp/dataaug/base_augment.py

Lines changed: 87 additions & 9 deletions

@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,19 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 import re
-import math
-import random
 from typing import Iterable
 
-import numpy as np
-import paddle
 from paddle.dataset.common import md5file
 from paddle.utils.download import get_path_from_url
 
+from ..data import JiebaTokenizer, Vocab
 from ..utils.env import DATA_HOME
-from ..data import Vocab, JiebaTokenizer
 
 
 class BaseAugment(object):
@@ -44,7 +41,7 @@ class BaseAugment(object):
             Maximum number of augmented words in sequences.
     """
 
-    def __init__(self, create_n, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10):
+    def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10, vocab="vocab"):
         self._DATA = {
             "stop_words": (
                 "stopwords.txt",
@@ -56,24 +53,49 @@ def __init__(self, create_n, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10
                 "25c2d41aec5a6d328a65c1995d4e4c2e",
                 "https://bj.bcebos.com/paddlenlp/data/baidu_encyclopedia_w2v_vocab.json",
             ),
+            "test_vocab": (
+                "test_vocab.json",
+                "1d2fce1c80a4a0ec2e90a136f339ab88",
+                "https://bj.bcebos.com/paddlenlp/data/test_vocab.json",
+            ),
             "word_synonym": (
                 "word_synonym.json",
                 "aaa9f864b4af4123bce4bf138a5bfa0d",
                 "https://bj.bcebos.com/paddlenlp/data/word_synonym.json",
             ),
+            "word_embedding": (
+                "word_embedding.json",
+                "534aa4ad274def4deff585cefd8ead32",
+                "https://bj.bcebos.com/paddlenlp/data/word_embedding.json",
+            ),
             "word_homonym": (
                 "word_homonym.json",
                 "a578c04201a697e738f6a1ad555787d5",
                 "https://bj.bcebos.com/paddlenlp/data/word_homonym.json",
            ),
+            "char_homonym": (
+                "char_homonym.json",
+                "dd98d5d5d32a3d3dd45c8f7ca503c7df",
+                "https://bj.bcebos.com/paddlenlp/data/char_homonym.json",
+            ),
+            "char_antonym": (
+                "char_antonym.json",
+                "f892f5dce06f17d19949ebcbe0ed52b7",
+                "https://bj.bcebos.com/paddlenlp/data/char_antonym.json",
+            ),
+            "word_antonym": (
+                "word_antonym.json",
+                "cbea11fa99fbe9d07e8185750b37e84a",
+                "https://bj.bcebos.com/paddlenlp/data/word_antonym.json",
+            ),
         }
         self.stop_words = self._get_data("stop_words")
         self.aug_n = aug_n
         self.aug_percent = aug_percent
         self.aug_min = aug_min
         self.aug_max = aug_max
         self.create_n = create_n
-        self.vocab = Vocab.from_json(self._load_file("vocab"))
+        self.vocab = Vocab.from_json(self._load_file(vocab))
         self.tokenizer = JiebaTokenizer(self.vocab)
         self.loop = 5
 
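`BaseAugment.__init__` now provides defaults for every argument (`create_n=1`; `aug_percent` raised from 0.02 to 0.1) and takes a `vocab` key so callers can select an alternative vocabulary such as the new `test_vocab` resource. A minimal sketch of a subclass picking up these defaults (the subclass itself is hypothetical; only `_augment` must be overridden, per the diff below):

```python
# Hypothetical subclass showing the new defaults: one augmented sequence
# per input (create_n=1), roughly 10% of tokens touched (aug_percent=0.1),
# bounded by aug_min=1 and aug_max=10, using the default "vocab" resource.
from paddlenlp.dataaug.base_augment import BaseAugment


class ReverseAugment(BaseAugment):
    """Toy strategy for illustration: reverse each sequence."""

    def _augment(self, sequence):
        return [sequence[::-1] for _ in range(self.create_n)]


aug = ReverseAugment()            # all constructor arguments now optional
print(aug.augment("千山鸟飞绝"))   # -> [["绝飞鸟山千"]]
```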
@@ -150,7 +172,7 @@ def augment(self, sequences, num_thread=1):
         # Single Thread
         if num_thread == 1:
             if isinstance(sequences, str):
-                return self._augment(sequences)
+                return [self._augment(sequences)]
             else:
                 output = []
                 for sequence in sequences:
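
This one-line change normalizes the return shape: a single string input previously returned a flat list of variants, while a list input returned one list per sequence; now both come back as a list of lists. A before/after sketch (the `WordSubstitute` name is assumed from the pre-existing word-level API):

```python
from paddlenlp.dataaug import WordSubstitute  # class name assumed

aug = WordSubstitute("synonym", create_n=2)

# Before: aug.augment("一句话")             -> ["变体一", "变体二"]
# After:  aug.augment("一句话")             -> [["变体一", "变体二"]]
# Either way: aug.augment(["句一", "句二"])  -> [[...], [...]]
```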
@@ -161,3 +183,59 @@ def augment(self, sequences, num_thread=1):
 
     def _augment(self, sequence):
         raise NotImplementedError
+
+
+class FileAugment(object):
+    """
+    File data augmentation
+
+    Args:
+        strategies (List):
+            List of augmentation strategies.
+    """
+
+    def __init__(self, strategies):
+        self.strategies = strategies
+
+    def augment(self, input_file, output_file="aug.txt", separator=None, separator_id=0):
+        output_sequences = []
+        sequences = []
+
+        input_sequences = self.file_read(input_file)
+
+        if separator:
+            for input_sequence in input_sequences:
+                sequences.append(input_sequence.split(separator)[separator_id])
+        else:
+            sequences = input_sequences
+
+        for strategy in self.strategies:
+            aug_sequences = strategy.augment(sequences)
+            if separator:
+                for aug_sequence, input_sequence in zip(aug_sequences, input_sequences):
+                    input_items = input_sequence.split(separator)
+                    for s in aug_sequence:
+                        input_items[separator_id] = s
+                        output_sequences.append(separator.join(input_items))
+            else:
+                for aug_sequence in aug_sequences:
+                    output_sequences += aug_sequence
+
+        if output_file:
+            self.file_write(output_sequences, output_file)
+
+        return output_sequences
+
+    def file_read(self, input_file):
+        input_sequences = []
+        with open(input_file, "r", encoding="utf-8") as f:
+            for line in f:
+                input_sequences.append(line.strip())
+            f.close()
+        return input_sequences
+
+    def file_write(self, output_sequences, output_file):
+        with open(output_file, "w", encoding="utf-8") as f:
+            for output_sequence in output_sequences:
+                f.write(output_sequence + "\n")
+            f.close()
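
`FileAugment` reads one example per line, optionally splits each line on `separator` so that only the field at `separator_id` is augmented (the other fields, e.g. labels, are carried over unchanged), and writes the reassembled lines to `output_file`. A usage sketch; the file name and the word-level strategy are assumptions:

```python
# Hypothetical end-to-end run over a tab-separated "text<TAB>label" file,
# augmenting only column 0 and leaving the label column untouched.
from paddlenlp.dataaug import FileAugment, WordSubstitute  # strategy name assumed

file_aug = FileAugment([WordSubstitute("synonym", create_n=1)])
aug_lines = file_aug.augment("train.txt", output_file="aug.txt",
                             separator="\t", separator_id=0)
```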
