-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
113 lines (86 loc) · 3.04 KB
/
data_utils.py
File metadata and controls
113 lines (86 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import gzip
from util import load_dict
# Extra vocabulary symbols.
_GO = '_GO'
EOS = '_EOS'  # also serves as the padding symbol
UNK = '_UNK'

# A symbol's position in this list is its integer token id.
extra_tokens = [_GO, EOS, UNK]
start_token, end_token, unk_token = map(extra_tokens.index, (_GO, EOS, UNK))
# start_token = 0, end_token = 1, unk_token = 2
def fopen(filename, mode='r'):
    """Open ``filename``, transparently handling gzip-compressed files.

    Text modes (no ``'b'`` in ``mode``) always yield ``str`` decoded as
    UTF-8, whether or not the file is gzip-compressed. Binary modes yield
    ``bytes`` and are opened without an encoding.

    Args:
        filename: Path of the file; a ``.gz`` suffix selects gzip handling.
        mode: Mode string as accepted by ``open`` (default ``'r'``).

    Returns:
        An open file object; the caller is responsible for closing it.
    """
    binary = 'b' in mode

    if filename.endswith('.gz'):
        if binary:
            return gzip.open(filename, mode)
        # gzip.open treats a bare 'r'/'w'/'a' as *binary*; text mode must be
        # requested explicitly with 't', otherwise reads return bytes.
        text_mode = mode if 't' in mode else mode + 't'
        return gzip.open(filename, text_mode, encoding='utf8')

    if binary:
        # open() raises ValueError when given an encoding in binary mode.
        return open(filename, mode)
    return open(filename, mode, encoding='utf8')
def load_inverse_dict(dict_path):
    """Load the dictionary at ``dict_path`` and return it inverted.

    The mapping returned by ``load_dict`` is flipped so that each original
    value becomes a key pointing at its original key (presumably
    index -> word; confirm against ``load_dict``). If several keys share a
    value, the one iterated last wins.
    """
    return {idx: words for words, idx in load_dict(dict_path).items()}
def seq2words(seq, inverse_target_dictionary):
    """Convert a sequence of token ids into a space-joined string.

    Conversion stops at the first ``end_token`` (which is not emitted).
    Ids missing from ``inverse_target_dictionary`` are rendered as ``UNK``.
    """
    tokens = []
    for token_id in seq:
        if token_id == end_token:
            break
        known = token_id in inverse_target_dictionary
        tokens.append(inverse_target_dictionary[token_id] if known else UNK)
    return ' '.join(tokens)
def prepare_batch(seqs_x, maxlen=None):
    """Pack a list of sequences into one right-padded int32 matrix.

    Args:
        seqs_x: A list of sentences, each a sized sequence whose items can
            be stored in an int32 row (presumably token ids).
        maxlen: If given, sentences longer than ``maxlen`` are dropped
            (not truncated) before batching.

    Returns:
        ``(x, x_lengths)`` where ``x`` is an int32 array with one row per
        kept sentence, padded on the right with ``end_token`` up to the
        longest kept sentence, and ``x_lengths`` is an array of the
        unpadded lengths. Returns ``(None, None)`` when no sentence remains.
    """
    if maxlen is not None:
        seqs_x = [s_x for s_x in seqs_x if len(s_x) <= maxlen]
    lengths_x = [len(s_x) for s_x in seqs_x]

    # Guard regardless of maxlen: np.max raises ValueError on an empty array.
    if not lengths_x:
        return None, None

    x_lengths = np.array(lengths_x)
    maxlen_x = np.max(x_lengths)

    # Pre-fill with the padding id so only the real tokens need copying in.
    x = np.full((len(seqs_x), maxlen_x), end_token, dtype='int32')
    for row, (s_x, l_x) in enumerate(zip(seqs_x, lengths_x)):
        x[row, :l_x] = s_x
    return x, x_lengths
def prepare_train_batch(seqs_x, seqs_y, maxlen=None):
    """Pack paired source/target sequences into right-padded int32 matrices.

    Args:
        seqs_x: A list of source sentences, each a sized sequence whose
            items can be stored in an int32 row (presumably token ids).
        seqs_y: A list of target sentences, expected to be aligned
            pairwise with ``seqs_x``.
        maxlen: If given, a pair is dropped (not truncated) when either of
            its sentences is longer than ``maxlen``.

    Returns:
        ``(x, x_lengths, y, y_lengths)``. ``x`` and ``y`` are int32 arrays
        padded on the right with ``end_token`` up to the longest kept
        sentence on their side; the length arrays hold the unpadded
        lengths. Returns ``(None, None, None, None)`` when no pair remains.
    """
    if maxlen is not None:
        kept = [
            (s_x, s_y)
            for s_x, s_y in zip(seqs_x, seqs_y)
            if len(s_x) <= maxlen and len(s_y) <= maxlen
        ]
        seqs_x = [s_x for s_x, _ in kept]
        seqs_y = [s_y for _, s_y in kept]

    lengths_x = [len(s_x) for s_x in seqs_x]
    lengths_y = [len(s_y) for s_y in seqs_y]

    # Guard regardless of maxlen: np.max raises ValueError on an empty array.
    if not lengths_x or not lengths_y:
        return None, None, None, None

    batch_size = len(seqs_x)
    x_lengths = np.array(lengths_x)
    y_lengths = np.array(lengths_y)
    maxlen_x = np.max(x_lengths)
    maxlen_y = np.max(y_lengths)

    # Pre-fill with the padding id so only the real tokens need copying in.
    x = np.full((batch_size, maxlen_x), end_token, dtype='int32')
    y = np.full((batch_size, maxlen_y), end_token, dtype='int32')
    for row, (s_x, s_y) in enumerate(zip(seqs_x, seqs_y)):
        x[row, :lengths_x[row]] = s_x
        y[row, :lengths_y[row]] = s_y
    return x, x_lengths, y, y_lengths