forked from NUSTM/RTHN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer.py
More file actions
194 lines (155 loc) · 6.86 KB
/
transformer.py
File metadata and controls
194 lines (155 loc) · 6.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# -*- encoding:utf-8 -*-
'''
@time: 2019/05/31
@author: mrzhang
@email: zhangmengran@njust.edu.cn
'''
from __future__ import print_function
import tensorflow as tf
import numpy as np
def normalize(inputs,
              epsilon=1e-8,
              scope="ln",
              reuse=None):
    '''Normalizes `inputs` over its last axis and applies a learned scale/shift.

    Args:
      inputs: A tensor; statistics are computed along the last dimension.
      epsilon: Small float added to the variance before taking the square
        root, guarding against division by zero.
      scope: Name passed to `tf.variable_scope`.
      reuse: Forwarded to `tf.variable_scope`.
        NOTE(review): the scale/shift parameters are created with
        `tf.Variable`, not `tf.get_variable`, so `reuse` presumably does not
        share them between calls -- confirm before relying on it.

    Returns:
      A tensor of the same shape as `inputs`.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # One learnable parameter per feature of the last dimension.
        feature_shape = inputs.get_shape()[-1:]
        mu, sigma_sq = tf.nn.moments(inputs, [-1], keep_dims=True)

        # Creation order (shift first, then scale) is kept so that the
        # auto-generated variable names stay the same.
        shift = tf.Variable(tf.zeros(feature_shape))
        scale = tf.Variable(tf.ones(feature_shape))

        standardized = (inputs - mu) / ((sigma_sq + epsilon) ** 0.5)
        return scale * standardized + shift
def multihead_attention(queries,
                        keys,
                        values,
                        units_query,
                        num_heads=8,
                        dropout_rate=0,
                        is_training=True,
                        scope="multihead_attention"):
    '''Applies scaled dot-product multi-head attention with a residual add.

    Args:
      queries: A 3d tensor; described by the original author as
        [batch, max_doc, 2*n_hidden+pos_dim].
      keys: A 3d tensor; described by the original author as
        [batch, max_doc, 2*n_hidden+pos_dim].
      values: A 3d tensor; described by the original author as
        [batch, max_doc, 2*n_hidden].
      units_query: An int. Output size of each of the three linear
        projections (the attention size C). It is split evenly across
        `num_heads`, and the attention output is added to `queries`, so it
        has to be shape-compatible with the last dimension of `queries`.
      num_heads: An int. Number of heads.
      dropout_rate: A floating point number. Dropout rate applied to the
        attention weights.
      is_training: Boolean. Dropout is only active when this is true.
      scope: Scope name for `tf.variable_scope`.

    Returns:
      A 3d tensor with shape of (N, T_q, C).
    '''
    def split_heads(tensor):
        # (N, T, C) -> (h*N, T, C/h): cut the channels into heads and stack
        # the heads along the batch axis.
        return tf.concat(tf.split(tensor, num_heads, axis=2), axis=0)

    def position_flags(tensor):
        # (N, T, C) -> (h*N, T): 0 where the features of a position sum to
        # zero (treated as padding), 1 elsewhere, repeated once per head.
        flags = tf.sign(tf.abs(tf.reduce_sum(tensor, axis=-1)))
        return tf.tile(flags, [num_heads, 1])

    with tf.variable_scope(scope):
        # Linear projections, each followed by ReLU.  The Q, K, V order is
        # kept so the auto-numbered dense layers get the same names.
        q_proj = tf.layers.dense(queries, units_query, activation=tf.nn.relu)  # (N, T_q, C)
        k_proj = tf.layers.dense(keys, units_query, activation=tf.nn.relu)  # (N, T_k, C)
        v_proj = tf.layers.dense(values, units_query, activation=tf.nn.relu)  # (N, T_k, C)

        q_heads = split_heads(q_proj)  # (h*N, T_q, C/h)
        k_heads = split_heads(k_proj)  # (h*N, T_k, C/h)
        v_heads = split_heads(v_proj)  # (h*N, T_k, C/h)

        # Dot-product scores, scaled by sqrt of the per-head depth.
        scores = tf.matmul(q_heads, tf.transpose(k_heads, [0, 2, 1]))  # (h*N, T_q, T_k)
        head_depth = k_heads.get_shape().as_list()[-1]
        scores = scores / (head_depth ** 0.5)

        # Key masking: padded keys get a very negative score so that the
        # softmax assigns them (almost) no weight.
        key_mask = position_flags(keys)  # (h*N, T_k)
        key_mask = tf.tile(tf.expand_dims(key_mask, 1),
                           [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)
        very_negative = tf.ones_like(scores) * (-2 ** 32 + 1)
        scores = tf.where(tf.equal(key_mask, 0), very_negative, scores)  # (h*N, T_q, T_k)

        weights = tf.nn.softmax(scores)  # (h*N, T_q, T_k)

        # Query masking: zero out the attention rows of padded queries.
        query_mask = position_flags(queries)  # (h*N, T_q)
        query_mask = tf.tile(tf.expand_dims(query_mask, -1),
                             [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        weights = weights * query_mask  # (h*N, T_q, T_k)

        weights = tf.layers.dropout(
            weights, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        # Weighted sum of the values, then merge the heads back.
        context = tf.matmul(weights, v_heads)  # (h*N, T_q, C/h)
        context = tf.concat(tf.split(context, num_heads, axis=0), axis=2)  # (N, T_q, C)

        # Residual connection: add the queries (which carry the concatenated
        # distance/position features).  Adding or concatenating `values`
        # instead were alternatives left disabled by the original author.
        context = context + queries

        result = normalize(context)  # (N, T_q, C)
    return result
def feedforward(inputs,
                num_units=(2048, 512),
                scope="multihead_attention",
                reuse=None):
    '''Point-wise feed forward net with a residual connection.

    Args:
      inputs: A 3d tensor with shape of [N, T, C].
      num_units: A sequence of two integers: the number of filters of the
        inner layer and of the readout layer.  Because the readout is added
        to `inputs`, `num_units[1]` has to be shape-compatible with C.
        The default is an immutable tuple so it cannot be modified by
        accident between calls.
      scope: Scope name for `tf.variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A 3d tensor with the same shape and dtype as inputs
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Inner layer: kernel-size-1 convolution, i.e. the same dense
        # transform applied independently at every position, with ReLU.
        hidden = tf.layers.conv1d(inputs=inputs,
                                  filters=num_units[0],
                                  kernel_size=1,
                                  activation=tf.nn.relu,
                                  use_bias=True)

        # Readout layer: linear projection to num_units[1] channels.
        outputs = tf.layers.conv1d(inputs=hidden,
                                   filters=num_units[1],
                                   kernel_size=1,
                                   activation=None,
                                   use_bias=True)

        # Residual connection
        outputs += inputs

        # Normalize
        outputs = normalize(outputs)
    return outputs
def feedforward_1(inputs, num_units, out_units):
    '''Two dense layers with a ReLU in between.

    Args:
      inputs: Input tensor fed to the first dense layer.
      num_units: An int. Output size of the first (hidden) dense layer.
      out_units: An int. Output size of the second dense layer.

    Returns:
      The output of the second dense layer.  No residual connection or
      normalization is applied (both were tried by the original author and
      left disabled).
    '''
    hidden = tf.layers.dense(inputs, num_units, use_bias=True)  # (N, T_q, C)
    activated = tf.nn.relu(hidden)
    return tf.layers.dense(activated, out_units, use_bias=True)
def label_smoothing(inputs, epsilon=0.1):
    '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.

    Every entry x of `inputs` becomes (1 - epsilon) * x + epsilon / K, where
    K is the size of the last dimension.

    Args:
      inputs: A 3d tensor with shape of [N, T, V], where V is the number of
        vocabulary.
      epsilon: Smoothing rate.

    Returns:
      The smoothed tensor, same shape as `inputs`.

    For example,

    ```
    import tensorflow as tf
    inputs = tf.convert_to_tensor([[[0, 0, 1],
                                    [0, 1, 0],
                                    [1, 0, 0]],
                                   [[1, 0, 0],
                                    [1, 0, 0],
                                    [0, 1, 0]]], tf.float32)
    outputs = label_smoothing(inputs)
    with tf.Session() as sess:
        print(sess.run([outputs]))
    >>
    [array([[[ 0.03333334,  0.03333334,  0.93333334],
             [ 0.03333334,  0.93333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334]],
            [[ 0.93333334,  0.03333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334],
             [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]
    ```
    '''
    num_channels = inputs.get_shape().as_list()[-1]  # number of channels
    if num_channels is None:
        # The last dimension is not known statically (e.g. a placeholder
        # with a `None` dimension); dividing by None would raise, so read
        # the size from the runtime shape instead.
        num_channels = tf.cast(tf.shape(inputs)[-1], inputs.dtype)
    return ((1 - epsilon) * inputs) + (epsilon / num_channels)