-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtransformer_funcs.py
More file actions
191 lines (143 loc) · 6.99 KB
/
transformer_funcs.py
File metadata and controls
191 lines (143 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import numpy as np
import tensorflow as tf
import numpy as np  # NOTE(review): duplicate of the numpy import above — safe to remove.
from attenvis import AttentionVis
# Shared visualization hook; the decorators / context managers used below
# record attention matrices for inspection and do not affect the math.
av = AttentionVis()
@av.att_mat_func
def Attention_Matrix(K, Q, use_mask=False):
    """
    Computes the scaled dot-product attention matrix for a single head.

    :param K: keys tensor     [batch_size x window_size_keys x embedding_size]
    :param Q: queries tensor  [batch_size x window_size_queries x embedding_size]
    :param use_mask: if True, apply a causal (look-ahead) mask so each query
        position only attends to key positions at or before it
    :return: attention weights [batch_size x window_size_queries x window_size_keys]
    """
    window_size_queries = Q.get_shape()[1]  # number of query positions
    window_size_keys = K.get_shape()[1]     # number of key positions

    # Causal mask: -inf strictly above the diagonal so softmax zeroes out
    # future positions. (np.NINF was removed in NumPy 2.0; -np.inf is the
    # portable spelling.)
    mask = tf.convert_to_tensor(
        value=np.transpose(np.tril(np.ones(
            (window_size_queries, window_size_keys)) * -np.inf, -1), (1, 0)),
        dtype=tf.float32)
    atten_mask = tf.tile(
        tf.reshape(mask, [-1, window_size_queries, window_size_keys]),
        [tf.shape(input=K)[0], 1, 1])

    # BUG FIX: scale by sqrt(d_k) where d_k is the embedding size (axis 2).
    # The original used Q.get_shape()[0], which is the BATCH size.
    dk = Q.get_shape()[2]

    # Raw scores Q @ K^T -> [batch x queries x keys]; mask, scale, softmax.
    scores = tf.linalg.matmul(Q, tf.transpose(K, perm=[0, 2, 1]))
    if use_mask:
        scores = scores + atten_mask
    scores = scores / (dk ** 0.5)
    return tf.nn.softmax(scores)
class Atten_Head(tf.keras.layers.Layer):
    """Single attention head with learned key/value/query projections."""

    def __init__(self, input_size, output_size, use_mask):
        super(Atten_Head, self).__init__()
        self.use_mask = use_mask
        # One trainable projection matrix per role, each [input_size x output_size].
        self.K_weight = self.add_weight(shape=[input_size, output_size])
        self.V_weight = self.add_weight(shape=[input_size, output_size])
        self.Q_weight = self.add_weight(shape=[input_size, output_size])

    @tf.function
    def call(self, inputs_for_keys, inputs_for_values, inputs_for_queries):
        """
        Runs a single attention head.

        :param inputs_for_keys:    [batch_size x WINDOW_SIZE x input_size]
        :param inputs_for_values:  [batch_size x WINDOW_SIZE x input_size]
        :param inputs_for_queries: [batch_size x WINDOW_SIZE x input_size]
        :return: [batch_size x WINDOW_SIZE x output_size]
        """
        # Project each input by contracting its feature axis (axis 2) against
        # the corresponding weight matrix.
        def project(x, w):
            return tf.tensordot(x, w, [[2], [0]])

        keys = project(inputs_for_keys, self.K_weight)
        values = project(inputs_for_values, self.V_weight)
        queries = project(inputs_for_queries, self.Q_weight)

        # Attention weights, then weighted sum of values per query position.
        weights = Attention_Matrix(keys, queries, self.use_mask)
        return tf.linalg.matmul(weights, values)
class Multi_Headed(tf.keras.layers.Layer):
    """Multi-headed attention: three heads of size emb_sz // 3 run in
    parallel, their outputs are concatenated and mixed back to emb_sz by a
    final linear layer (was an unimplemented TODO stub returning None).

    NOTE(review): head width is emb_sz // 3, so when emb_sz is not divisible
    by 3 the concatenation has width 3 * (emb_sz // 3); the final Dense layer
    restores the output to emb_sz either way.
    """

    def __init__(self, emb_sz, use_mask):
        super(Multi_Headed, self).__init__()
        head_sz = emb_sz // 3
        # Three independent attention heads sharing the mask setting.
        self.head0 = Atten_Head(emb_sz, head_sz, use_mask)
        self.head1 = Atten_Head(emb_sz, head_sz, use_mask)
        self.head2 = Atten_Head(emb_sz, head_sz, use_mask)
        # Output projection applied to the concatenated head outputs.
        self.out_dense = tf.keras.layers.Dense(emb_sz)

    @tf.function
    def call(self, inputs_for_keys, inputs_for_values, inputs_for_queries):
        """
        Runs a multiheaded attention layer.

        - Feeds the inputs through 3 heads of size emb_sz // 3
        - Concatenates the head outputs along the feature axis
        - Applies a linear layer to return to emb_sz

        :param inputs_for_keys:    [batch_size x WINDOW_SIZE x input_size]
        :param inputs_for_values:  [batch_size x WINDOW_SIZE x input_size]
        :param inputs_for_queries: [batch_size x WINDOW_SIZE x input_size]
        :return: [batch_size x WINDOW_SIZE x emb_sz]
        """
        outs = [head(inputs_for_keys, inputs_for_values, inputs_for_queries)
                for head in (self.head0, self.head1, self.head2)]
        return self.out_dense(tf.concat(outs, axis=-1))
class Feed_Forwards(tf.keras.layers.Layer):
    """Position-wise feed-forward network as described in section 3.3 of
    "Attention Is All You Need" (https://arxiv.org/pdf/1706.03762.pdf):
    Dense -> ReLU -> Dense, both layers of width emb_sz."""

    def __init__(self, emb_sz):
        super(Feed_Forwards, self).__init__()
        self.layer_1 = tf.keras.layers.Dense(emb_sz, activation='relu')
        self.layer_2 = tf.keras.layers.Dense(emb_sz)

    @tf.function
    def call(self, inputs):
        """
        Applies the two-layer feed-forward network position-wise.

        :param inputs: [batch_size x window_size x embedding_size]
        :return: [batch_size x window_size x embedding_size]
        """
        # ReLU lives inside layer_1's activation; layer_2 is purely linear.
        return self.layer_2(self.layer_1(inputs))
class Transformer_Block(tf.keras.layers.Layer):
    """One transformer block: (masked) self-attention, optional
    encoder-decoder cross-attention, and a feed-forward sublayer, each
    followed by a residual connection and layer normalization.

    NOTE(review): a single LayerNormalization instance is reused for every
    sublayer, so its scale/offset parameters are shared across them — confirm
    this tying is intentional before reusing this block elsewhere.
    """

    def __init__(self, emb_sz, is_decoder, multi_headed=False):
        super(Transformer_Block, self).__init__()
        self.ff_layer = Feed_Forwards(emb_sz)
        # Decoder blocks mask their self-attention; encoder blocks do not.
        if multi_headed:
            self.self_atten = Multi_Headed(emb_sz, use_mask=is_decoder)
        else:
            self.self_atten = Atten_Head(emb_sz, emb_sz, use_mask=is_decoder)
        self.is_decoder = is_decoder
        if self.is_decoder:
            # Unmasked cross-attention over the encoder's output.
            if multi_headed:
                self.self_context_atten = Multi_Headed(emb_sz, use_mask=False)
            else:
                self.self_context_atten = Atten_Head(
                    emb_sz, emb_sz, use_mask=False)
        self.layer_norm = tf.keras.layers.LayerNormalization(axis=-1)

    @tf.function
    def call(self, inputs, context=None):
        """
        Runs the transformer block.

        Encoder path (is_decoder == False):
            self-attention -> add & norm -> feed-forward -> add & norm
        Decoder path (is_decoder == True):
            MASKED self-attention -> add & norm -> UNMASKED cross-attention
            over `context` -> add & norm -> feed-forward -> add & norm

        :param inputs: [batch_size x WINDOW_SIZE x EMBEDDING_SIZE]
        :param context: [batch_size x WINDOW_SIZE x EMBEDDING_SIZE] or None;
            encoder output used as keys/values in cross-attention (decoder only)
        :return: [batch_size x WINDOW_SIZE x EMBEDDING_SIZE]
        """
        with av.trans_block(self.is_decoder):
            # Self-attention sublayer with residual connection + norm.
            normed = self.layer_norm(
                self.self_atten(inputs, inputs, inputs) + inputs)
            if self.is_decoder:
                assert context is not None, "Decoder blocks require context"
                # Cross-attention: queries from this block, keys/values from
                # the encoder context; residual + norm again.
                cross = self.self_context_atten(context, context, normed)
                normed = self.layer_norm(cross + normed)
            # Feed-forward sublayer, residual + norm, final ReLU.
            ff = self.ff_layer(normed) + normed
            return tf.nn.relu(self.layer_norm(ff))
class Position_Encoding_Layer(tf.keras.layers.Layer):
    """Learned positional embeddings: one trainable emb_sz-vector per window
    position, added to the word embeddings."""

    def __init__(self, window_sz, emb_sz):
        super(Position_Encoding_Layer, self).__init__()
        # One trainable row per position in the window.
        self.positional_embeddings = self.add_weight(
            "pos_embed", shape=[window_sz, emb_sz])

    @tf.function
    def call(self, x):
        """
        Adds positional embeddings to the word embeddings.

        :param x: [batch_size x WINDOW_SIZE x EMBEDDING_SIZE] input embeddings
        :return: [batch_size x WINDOW_SIZE x EMBEDDING_SIZE] embeddings with
            positional encodings added
        """
        # [window_sz, emb_sz] broadcasts across the batch dimension.
        return tf.add(x, self.positional_embeddings)