
Commit 28d16ff

Add AdamW optimizer support for World Language Model example (#1380)
1 parent 993a98a commit 28d16ff

File tree

2 files changed (+20, -6 lines)


word_language_model/README.md

Lines changed: 5 additions & 2 deletions
@@ -8,9 +8,11 @@ python main.py --accel --epochs 6           # Train a LSTM on Wikitext-2.
 python main.py --accel --epochs 6 --tied    # Train a tied LSTM on Wikitext-2.
 python main.py --accel --tied               # Train a tied LSTM on Wikitext-2 for 40 epochs.
 python main.py --accel --epochs 6 --model Transformer --lr 5
-                                            # Train a Transformer model on Wikitext-2.
+                                            # Train a Transformer model on Wikitext-2.
+python main.py --accel --epochs 6 --model Transformer --use-optimizer --lr 0.001
+                                            # Train a Transformer model with AdamW optimizer on Wikitext-2.

-python generate.py --accel                  # Generate samples from the default model checkpoint.
+python generate.py --accel                  # Generate samples from the default model checkpoint.
 ```

 > [!NOTE]

@@ -45,6 +47,7 @@ optional arguments:
                         path to export the final model in onnx format
   --nhead NHEAD         the number of heads in the encoder/decoder of the transformer model
   --dry-run             verify the code and the model
+  --use-optimizer       specify whether to use an AdamW optimizer
 ```

 With these arguments, a variety of models can be tested.
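For context on the learning rates in the commands above: with `--use-optimizer`, the example hands the `--lr` value directly to `torch.optim.AdamW` (see the `main.py` diff below), which is why an Adam-range value such as 0.001 replaces the 5 used with the manual update. A minimal sketch, with a placeholder `nn.Linear` model standing in for the example's networks:

```python
import torch
import torch.nn as nn

# Placeholder model for illustration; the example builds an RNN or Transformer from model.py.
model = nn.Linear(8, 8)

# --use-optimizer --lr 0.001  ->  the --lr value becomes AdamW's learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
print(optimizer.defaults["lr"])  # 0.001
```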

word_language_model/main.py

Lines changed: 15 additions & 4 deletions
@@ -47,7 +47,10 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
-parser.add_argument('--accel', action='store_true',help='Enables accelerated training')
+parser.add_argument('--accel', action='store_true',
+                    help='Enables accelerated training')
+parser.add_argument('--use-optimizer', action='store_true',
+                    help='Uses AdamW optimizer for gradient updating')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.

@@ -104,6 +107,8 @@ def batchify(data, bsz):
 model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

 criterion = nn.NLLLoss()
+if args.use_optimizer:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

 ###############################################################################
 # Training code

@@ -167,7 +172,10 @@ def train():
         data, targets = get_batch(train_data, i)
         # Starting each batch, we detach the hidden state from how it was previously produced.
         # If we didn't, the model would try backpropagating all the way to start of the dataset.
-        model.zero_grad()
+        if args.use_optimizer:
+            optimizer.zero_grad()
+        else:
+            model.zero_grad()
         if args.model == 'Transformer':
             output = model(data)
             output = output.view(-1, ntokens)

@@ -179,8 +187,11 @@ def train():

         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        for p in model.parameters():
-            p.data.add_(p.grad, alpha=-lr)
+        if args.use_optimizer:
+            optimizer.step()
+        else:
+            for p in model.parameters():
+                p.data.add_(p.grad, alpha=-lr)

         total_loss += loss.item()
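Taken together, the diff keeps the example's original manual SGD-style update and switches to AdamW only when the flag is set. Below is a minimal, self-contained sketch of the resulting training step; the `nn.Linear` model, random tensors, and the `use_optimizer`/`lr`/`clip` variables are placeholders standing in for the example's `RNNModel`, Wikitext-2 batches, and `args` values.

```python
import torch
import torch.nn as nn

# Placeholder pieces for illustration; the real example uses RNNModel/TransformerModel
# and Wikitext-2 batches rather than nn.Linear and random tensors.
model = nn.Linear(16, 16)
data, targets = torch.randn(32, 16), torch.randn(32, 16)
use_optimizer, lr, clip = True, 0.001, 0.25  # stand-ins for args.use_optimizer, args.lr, args.clip

# Mirrors the commit: construct AdamW only when the flag is set.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr) if use_optimizer else None

for step in range(3):
    # Clear gradients through the optimizer when it exists, otherwise on the model.
    if use_optimizer:
        optimizer.zero_grad()
    else:
        model.zero_grad()

    loss = nn.functional.mse_loss(model(data), targets)
    loss.backward()

    # Gradient clipping is kept on both paths, as in the example.
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

    # AdamW step, or the original manual SGD-style parameter update.
    if use_optimizer:
        optimizer.step()
    else:
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-lr)
```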
