TransformerEncoder.py
import numpy as np
import torch
import torch.nn as nn
#################################################
### TRANSFORMER CLASS
#################################################
'''
Here we define the transformer class used in obtaining the AnomalyCLR representations.
It is based on a standard pytorch TransformerEncoder set-up, with masking.
'''
class Transformer( nn.Module ):
# define and initialize the structure of the neural network
def __init__( self, input_dim, model_dim, output_dim, n_heads, dim_feedforward, n_layers, learning_rate, n_head_layers=2, head_norm=False, dropout=0.1, opt="adam" ):
super().__init__()
# define hyperparameters
self.input_dim = input_dim
self.model_dim = model_dim
self.output_dim = output_dim
self.n_heads = n_heads
self.dim_feedforward = dim_feedforward
self.n_layers = n_layers
self.learning_rate = learning_rate
self.n_head_layers = n_head_layers
self.head_norm = head_norm
self.dropout = dropout
# define subnetworks
self.embedding = nn.Linear(input_dim, model_dim)
self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(model_dim, n_heads, dim_feedforward=dim_feedforward, dropout=dropout), n_layers)
# define the head network using information from the input
if n_head_layers == 0:
self.head_layers = []
else:
if head_norm: self.norm_layers = nn.ModuleList([nn.LayerNorm(model_dim)])
self.head_layers = nn.ModuleList([nn.Linear(model_dim, output_dim)])
for i in range(n_head_layers-1):
if head_norm: self.norm_layers.append(nn.LayerNorm(output_dim))
self.head_layers.append(nn.Linear(output_dim, output_dim))
# set up the optimiser used during training
if opt == "adam":
self.optimizer = torch.optim.Adam( self.parameters(), lr=self.learning_rate )
elif opt in ("sgdca", "sgdslr", "sgd"):
self.optimizer = torch.optim.SGD( self.parameters(), lr=self.learning_rate, momentum=0.9 )
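# A hedged instantiation sketch (not taken from the original training scripts): the hyperparameter
# values below are illustrative assumptions, chosen so that model_dim is divisible by n_heads and
# output_dim matches model_dim:
#   net = Transformer(input_dim=7, model_dim=64, output_dim=64, n_heads=4,
#                     dim_feedforward=128, n_layers=2, learning_rate=1e-4,
#                     n_head_layers=2, head_norm=False, dropout=0.1, opt="adam")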
#################################################
#################################################
# the forward function determines how data is passed through the network
def forward(self, inpt, mask=None, use_mask=False, use_continuous_mask=False, mult_reps=False):
'''
input dims : (batch_size, n_constit, input_dim)
but the transformer expects (n_constit, batch_size, input_dim), so we need to transpose
if use_mask is True, all constituents with pT=0 are masked out
if use_continuous_mask is True, constituents are instead weighted by their pT
'''
# check that only one masking is selected in the input
assert not (use_mask and use_continuous_mask)
# make a copy of the inputs
x = inpt + 0.
# make a mask and send it to the device
if use_mask: pT_zero = x[:,:,0] == 0 # shape: (batch_size,n_constit)
if use_continuous_mask: pT = x[:,:,0] # shape: (batch_size,n_constit)
if use_mask:
mask = self.make_mask(pT_zero).to(x.device)
elif use_continuous_mask:
mask = self.make_continuous_mask(pT).to(x.device)
else:
mask = None
# transpose data for input to the transformer
x = torch.transpose(x, 0, 1) # shape: (n_constit,batch_size,7)
# pass data through the embedding layer
x = self.embedding(x) # shape: (n_constit,batch_size,model_dim)
# pass data through the transformer, including the mask
x = self.transformer(x, mask=mask) # shape: (n_constit,batch_size,model_dim)
# we need to apply the masking again to the outputs of the transformer
if use_mask:
x[torch.transpose(pT_zero, 0, 1)] = 0
elif use_continuous_mask:
x *= torch.transpose(pT, 0, 1)[:,:,None]
# sum the representation vectors over the constituents
x = x.sum(0) # (batch_size, model_dim)
# pass data through the head network and return the output
return self.head(x, mult_reps)
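# Example calls (a sketch; `net` and `batch` are hypothetical, with batch of shape (batch_size, n_constit, input_dim)):
#   reps = net(batch, use_mask=True)                  # shape (batch_size, output_dim)
#   reps = net(batch, use_mask=True, mult_reps=True)  # shape (batch_size, n_head_layers+1, output_dim)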
#################################################
#################################################
# the head network comes after the transformer network, and is where we obtain the representations from
def head(self, x, mult_reps):
'''
if mult_reps=True the head returns a representation vector for each layer of the head network
input: x shape : (batchsize,model_dim)
mult_reps : boolean
output: reps shape : (batchsize,output_dim) for mult_reps=False
reps shape : (batchsize,number_of_reps,output_dim) for mult_reps=True
'''
# define the activation function
relu = nn.ReLU()
# return representations from multiple layers for evaluation
if mult_reps:
# if the head has multiple layers, we loop through them
if self.n_head_layers > 0:
# define the reps tensor with the correct shape (note: this assumes output_dim == model_dim, so the raw transformer output fits in the first slot)
reps = torch.empty(x.shape[0], self.n_head_layers+1, self.output_dim)
# the first representation is the one coming straight from the transformer
reps[:, 0] = x
# loop through the head layers
for i, layer in enumerate(self.head_layers):
# only apply layer norm on head if head_norm=True
if self.head_norm: x = self.norm_layers[i](x)
# pass data through the activation function and each layer
x = relu(x)
x = layer(x)
# append the output of each layer to the reps object we created
reps[:, i+1] = x # x shape: (batchsize, output_dim)
# return the representations
return reps
# if there are no head layers, i.e. no head network -> just return x with an extra axis of size 1
else:
reps = x[:, None, :] # shape (batchsize, 1, model_dim)
# return the representations
return reps
# while training the CLR network the loss is computed on the whole network outputs
# so when mult_reps=False we return only the last representation from the head network
else:
# pass the data through the head network, this will do nothing if n_head_layers=0
for i, layer in enumerate(self.head_layers):
if self.head_norm: x = self.norm_layers[i](x)
x = relu(x)
x = layer(x) # shape (batchsize, model_dim) if there is no head, or (batchsize, output_dim) if a head exists
# return the final representation only
return x
#################################################
#################################################
# when we want to pass a large amount of data through the transformer to get representations, we need to do it batch-wise
# this function does that
def forward_batchwise( self, x, batch_size, use_mask=False, use_continuous_mask=False):
# set the device
device = next(self.parameters()).device
# turn off gradient calculation to save computation
with torch.no_grad():
# set number of reps based on n_head_layers
if self.n_head_layers == 0:
rep_dim = self.model_dim
number_of_reps = 1
elif self.n_head_layers > 0:
rep_dim = self.output_dim
number_of_reps = self.n_head_layers+1
# initialise the output tensor
out = torch.empty( x.size(0), number_of_reps, rep_dim )
# get idx lists of different batches
idx_list = torch.split( torch.arange( x.size(0) ), batch_size )
# iterate through batch idx lists
for idx in idx_list:
# pass each batch through the network
output = self(x[idx].to(device), use_mask=use_mask, use_continuous_mask=use_continuous_mask, mult_reps=True).detach().cpu()
# update the output array with the representations
out[idx] = output
# return the representations
return out
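# Example call (a sketch; `data` is a hypothetical tensor of shape (n_events, n_constit, input_dim)):
#   reps = net.forward_batchwise(data, batch_size=512, use_mask=True)
#   # reps shape: (n_events, n_head_layers+1, output_dim), or (n_events, 1, model_dim) if n_head_layers=0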
#################################################
#################################################
# we define a binary mask for the transformer
def make_mask(self, pT_zero):
'''
Input : batch of bools of whether pT=0, shape (batchsize, n_constit)
Output : mask for transformer model which masks out constituents with pT=0, shape (batchsize*n_transformer_heads, n_constit, n_constit)
mask is added to the attention scores before the softmax: 0 means the value is unchanged, -inf means it is masked out
'''
# get number of constituents
n_constit = pT_zero.size(1)
# produce the mask with the correct shape for the transformer architecture
pT_zero = torch.repeat_interleave(pT_zero, self.n_heads, dim=0)
pT_zero = torch.repeat_interleave(pT_zero[:,None], n_constit, dim=1)
mask = torch.zeros(pT_zero.size(0), n_constit, n_constit)
# the mask is additive in the attention softmax, so masking a constituent requires adding -np.inf
mask[pT_zero] = -np.inf
# return the mask
return mask
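# Toy example (illustrative only): with n_heads=1 and pT_zero = [[False, True]], the two
# repeat_interleave calls give a (1, 2, 2) boolean tensor [[[False, True], [False, True]]],
# so the returned mask is [[[0, -inf], [0, -inf]]] and the zero-pT constituent is hidden
# from every query position.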
#################################################
#################################################
# we define a continuous pT-dependent masking
def make_continuous_mask(self, pT):
'''
Input: batch of pT values, shape (batchsize, n_constit)
Output: mask for transformer model, 0.5*log(pT), shape (batchsize*n_transformer_heads, n_constit, n_constit)
mask is added to the attention scores before the softmax: 0 means the value is unchanged, -inf means it is masked out
intermediate values mean a constituent is only partially masked
this function implements IR safety in the transformer
'''
# get the number of constituents
n_constit = pT.size(1)
# create mask with the correct shape
pT_reshape = torch.repeat_interleave(pT, self.n_heads, dim=0)
pT_reshape = torch.repeat_interleave(pT_reshape[:,None], n_constit, dim=1)
# the mask weights each constituent's attention in proportion to pT^0.5 (since exp(0.5*log(pT)) = sqrt(pT)); the exponent can be modified
mask = 0.5*torch.log( pT_reshape )
# return the mask
return mask
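#################################################
#################################################
# A minimal smoke test, added here as a hedged sketch rather than as part of the original module:
# the hyperparameters and toy batch shapes below are illustrative assumptions only.
if __name__ == "__main__":
    # toy batch: 8 events, 20 constituents, 7 features each, with feature 0 playing the role of pT
    batch = torch.rand(8, 20, 7)
    # zero out the pT of the last 5 constituents so the binary mask has something to do
    batch[:, 15:, 0] = 0.
    net = Transformer(input_dim=7, model_dim=64, output_dim=64, n_heads=4,
                      dim_feedforward=128, n_layers=2, learning_rate=1e-4,
                      n_head_layers=2, head_norm=False, dropout=0.1, opt="adam")
    with torch.no_grad():
        # single representation per event: shape (8, 64)
        reps = net(batch, use_mask=True)
        # one representation per head layer plus the raw transformer output: shape (8, 3, 64)
        all_reps = net(batch, use_mask=True, mult_reps=True)
    # batch-wise evaluation, as used when embedding a full dataset: shape (8, 3, 64)
    batched_reps = net.forward_batchwise(batch, batch_size=4, use_mask=True)
    print(reps.shape, all_reps.shape, batched_reps.shape)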