@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,19 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 import os
 import re
-import math
-import random
 from typing import Iterable

-import numpy as np
-import paddle
 from paddle.dataset.common import md5file
 from paddle.utils.download import get_path_from_url

+from ..data import JiebaTokenizer, Vocab
 from ..utils.env import DATA_HOME
-from ..data import Vocab, JiebaTokenizer


 class BaseAugment(object):
@@ -44,7 +41,7 @@ class BaseAugment(object):
             Maximum number of augmented words in sequences.
     """

-    def __init__(self, create_n, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10):
+    def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10, vocab="vocab"):
         self._DATA = {
             "stop_words": (
                 "stopwords.txt",
@@ -56,24 +53,49 @@ def __init__(self, create_n, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10
                 "25c2d41aec5a6d328a65c1995d4e4c2e",
                 "https://bj.bcebos.com/paddlenlp/data/baidu_encyclopedia_w2v_vocab.json",
             ),
+            "test_vocab": (
+                "test_vocab.json",
+                "1d2fce1c80a4a0ec2e90a136f339ab88",
+                "https://bj.bcebos.com/paddlenlp/data/test_vocab.json",
+            ),
             "word_synonym": (
                 "word_synonym.json",
                 "aaa9f864b4af4123bce4bf138a5bfa0d",
                 "https://bj.bcebos.com/paddlenlp/data/word_synonym.json",
             ),
+            "word_embedding": (
+                "word_embedding.json",
+                "534aa4ad274def4deff585cefd8ead32",
+                "https://bj.bcebos.com/paddlenlp/data/word_embedding.json",
+            ),
             "word_homonym": (
                 "word_homonym.json",
                 "a578c04201a697e738f6a1ad555787d5",
                 "https://bj.bcebos.com/paddlenlp/data/word_homonym.json",
             ),
+            "char_homonym": (
+                "char_homonym.json",
+                "dd98d5d5d32a3d3dd45c8f7ca503c7df",
+                "https://bj.bcebos.com/paddlenlp/data/char_homonym.json",
+            ),
+            "char_antonym": (
+                "char_antonym.json",
+                "f892f5dce06f17d19949ebcbe0ed52b7",
+                "https://bj.bcebos.com/paddlenlp/data/char_antonym.json",
+            ),
+            "word_antonym": (
+                "word_antonym.json",
+                "cbea11fa99fbe9d07e8185750b37e84a",
+                "https://bj.bcebos.com/paddlenlp/data/word_antonym.json",
+            ),
         }
         self.stop_words = self._get_data("stop_words")
         self.aug_n = aug_n
         self.aug_percent = aug_percent
         self.aug_min = aug_min
         self.aug_max = aug_max
         self.create_n = create_n
-        self.vocab = Vocab.from_json(self._load_file("vocab"))
+        self.vocab = Vocab.from_json(self._load_file(vocab))
         self.tokenizer = JiebaTokenizer(self.vocab)
         self.loop = 5
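With the new signature, `BaseAugment` is usable with defaults alone (`create_n=1`, `aug_percent=0.1`), and the `vocab` argument selects which `self._DATA` entry backs the tokenizer, e.g. the small `test_vocab` for quick runs. Below is a minimal sketch of a concrete strategy on top of the new constructor; the `RepeatAugment` name and the `paddlenlp.dataaug.base_augment` import path are illustrative assumptions, not part of this change:

```python
from paddlenlp.dataaug.base_augment import BaseAugment  # assumed module path


class RepeatAugment(BaseAugment):
    """Toy strategy: returns the input unchanged, create_n times."""

    def _augment(self, sequence):
        # A real strategy would edit tokens here, e.g. via self.tokenizer.
        return [sequence for _ in range(self.create_n)]


# Defaults now work out of the box; vocab="test_vocab" pulls the small test vocab.
aug = RepeatAugment(create_n=2, vocab="test_vocab")
```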
@@ -150,7 +172,7 @@ def augment(self, sequences, num_thread=1):
         # Single Thread
         if num_thread == 1:
             if isinstance(sequences, str):
-                return self._augment(sequences)
+                return [self._augment(sequences)]
             else:
                 output = []
                 for sequence in sequences:
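Wrapping the single-string result in a list gives `augment` a uniform return shape: one inner list of augmented outputs per input sequence, whether the caller passes a `str` or a list of strings. With the toy `RepeatAugment` sketched above:

```python
aug = RepeatAugment(create_n=2, vocab="test_vocab")
aug.augment("今天天气很好")        # now: [["今天天气很好", "今天天气很好"]]
aug.augment(["句子一", "句子二"])  # same shape: one inner list per input sequence
```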
@@ -161,3 +183,59 @@ def augment(self, sequences, num_thread=1):

     def _augment(self, sequence):
         raise NotImplementedError
+
+
+class FileAugment(object):
+    """
+    File data augmentation.
+
+    Args:
+        strategies (List):
+            List of augmentation strategies.
+    """
+
+    def __init__(self, strategies):
+        self.strategies = strategies
+
+    def augment(self, input_file, output_file="aug.txt", separator=None, separator_id=0):
+        output_sequences = []
+        sequences = []
+
+        input_sequences = self.file_read(input_file)
+
+        # If a separator is given, augment only the field at separator_id.
+        if separator:
+            for input_sequence in input_sequences:
+                sequences.append(input_sequence.split(separator)[separator_id])
+        else:
+            sequences = input_sequences
+
+        for strategy in self.strategies:
+            aug_sequences = strategy.augment(sequences)
+            if separator:
+                # Splice each augmented field back into its original line.
+                for aug_sequence, input_sequence in zip(aug_sequences, input_sequences):
+                    input_items = input_sequence.split(separator)
+                    for s in aug_sequence:
+                        input_items[separator_id] = s
+                        output_sequences.append(separator.join(input_items))
+            else:
+                for aug_sequence in aug_sequences:
+                    output_sequences += aug_sequence
+
+        if output_file:
+            self.file_write(output_sequences, output_file)
+
+        return output_sequences
+
+    def file_read(self, input_file):
+        input_sequences = []
+        with open(input_file, "r", encoding="utf-8") as f:
+            for line in f:
+                input_sequences.append(line.strip())
+        return input_sequences
+
+    def file_write(self, output_sequences, output_file):
+        with open(output_file, "w", encoding="utf-8") as f:
+            for output_sequence in output_sequences:
+                f.write(output_sequence + "\n")
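For reference, a short end-to-end sketch of the new `FileAugment` class; the file names, the tab-separated `text<TAB>label` layout, and the `RepeatAugment` strategy from above are illustrative assumptions:

```python
# train.tsv holds lines like "text<TAB>label"; separator_id=0 means only the
# text column is augmented, and each label is carried over unchanged.
file_aug = FileAugment([RepeatAugment(create_n=1, vocab="test_vocab")])
augmented = file_aug.augment("train.tsv", output_file="aug.tsv", separator="\t", separator_id=0)
print(len(augmented))  # number of augmented lines also written to aug.tsv
```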