                    help='the number of heads in the encoder/decoder of the transformer model')
parser.add_argument('--dry-run', action='store_true',
                    help='verify the code and the model')
- parser.add_argument('--accel', action='store_true',help='Enables accelerated training')
+ parser.add_argument('--accel', action='store_true',
+                     help='Enables accelerated training')
+ parser.add_argument('--use-optimizer', action='store_true',
+                     help='Uses AdamW optimizer for gradient updating')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
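Since both new flags use `action='store_true'`, they default to False and argparse exposes `--use-optimizer` as `args.use_optimizer`. A minimal standalone sketch of that behaviour (illustrative only, not part of the patched script):

import argparse

# Re-declare just the two flags added above to show how they parse.
parser = argparse.ArgumentParser(description='flag behaviour sketch')
parser.add_argument('--accel', action='store_true',
                    help='Enables accelerated training')
parser.add_argument('--use-optimizer', action='store_true',
                    help='Uses AdamW optimizer for gradient updating')

args = parser.parse_args(['--use-optimizer'])
print(args.accel)          # False: flag not passed
print(args.use_optimizer)  # True: dashes in '--use-optimizer' become an underscore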
@@ -104,6 +107,8 @@ def batchify(data, bsz):
model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

criterion = nn.NLLLoss()
+ if args.use_optimizer:
+     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

###############################################################################
# Training code
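For reference, `torch.optim.AdamW` is Adam with decoupled weight decay. The patch reuses `args.lr`, which in this example script is tuned for the manual SGD-style update and is likely much larger than typical AdamW learning rates, so a smaller `--lr` would presumably be passed together with `--use-optimizer`. A standalone sketch with a placeholder model and an assumed AdamW-friendly rate:

import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # placeholder; the real script builds RNNModel(...)
# lr=1e-3 is an assumed illustrative value, not the script's args.lr
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)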
@@ -167,7 +172,10 @@ def train():
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
-         model.zero_grad()
+         if args.use_optimizer:
+             optimizer.zero_grad()
+         else:
+             model.zero_grad()
        if args.model == 'Transformer':
            output = model(data)
            output = output.view(-1, ntokens)
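Both branches clear the gradients accumulated by the previous backward pass; `optimizer.zero_grad()` resets only the parameters that were handed to the optimizer, which here is the same set that `model.zero_grad()` touches. A standalone sketch with a placeholder model (not the example's RNNModel):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model(torch.randn(2, 4)).sum().backward()   # populates .grad on the parameters
optimizer.zero_grad()                       # equivalent here to model.zero_grad()
print(model.weight.grad)                    # None on recent PyTorch (set_to_none defaults to True)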
@@ -179,8 +187,11 @@ def train():
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-         for p in model.parameters():
-             p.data.add_(p.grad, alpha=-lr)
+         if args.use_optimizer:
+             optimizer.step()
+         else:
+             for p in model.parameters():
+                 p.data.add_(p.grad, alpha=-lr)

        total_loss += loss.item()
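The replaced loop is a plain SGD step, p = p - lr * p.grad, applied after gradient clipping; `optimizer.step()` instead applies AdamW's adaptive, weight-decayed update to the same parameters. A standalone sketch of the manual path (placeholder model and assumed learning rate):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
lr = 0.1  # assumed illustrative value
model(torch.randn(2, 4)).sum().backward()

with torch.no_grad():            # manual SGD-style update, as in the removed loop
    for p in model.parameters():
        p.add_(p.grad, alpha=-lr)

Using `torch.no_grad()` with `p.add_` is the currently recommended way to express what the original `p.data.add_` does: both perform the in-place update without recording it in the autograd graph.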