Skip to content

Commit 2c983cc

Browse files
authored
Merge pull request #53 from Achazwl/fix-adam-torch12
fix adam API changed in torch>=1.12.0
2 parents fb7603f + e56bef4 commit 2c983cc

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

bmtrain/optim/adam.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,8 @@ def step(self, closure=None, scale=1):
110 110   [state['exp_avg']],
111 111   [state["exp_avg_sq"]],
112 112   [],
113     - [state["step"]],
    113 + [state["step"]] if int(torch.__version__.split('.')[1]) < 12
    114 + else [torch.tensor(state["step"])],
114 115   amsgrad=False,
115 116   beta1=group['betas'][0],
116 117   beta2=group['betas'][1],

bmtrain/optim/adam_offload.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,8 @@ def step(self, closure=None, scale=1):
136 136   [state["exp_avg"]],
137 137   [state["exp_avg_sq"]],
138 138   [],
139     - [state["step"]],
    139 + [state["step"]] if int(torch.__version__.split('.')[1]) < 12
    140 + else [torch.tensor(state["step"])],
140 141   amsgrad=False,
141 142   beta1=beta1,
142 143   beta2=beta2,

tests/test_optim.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ def main():
25 25   model2.load_state_dict(state_dict)
26 26   model3.load_state_dict(state_dict)
27 27
28    - model1 = model1.cuda().half()
29    - model2 = model2.cuda().half()
   28 + model1 = model1.cuda()
   29 + model2 = model2.cuda()
30 30   model3 = model3.cuda()
31 31
32 32   opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3)

0 commit comments

Comments (0)