Skip to content

Commit 5129f0e

Browse files
committed
[fix] dtype & lr
1 parent aa7dc0f commit 5129f0e

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

model/model_minimind.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ def forward(self, x):
313 313
flat_topk_idx = topk_idx.view(-1)
314 314
if self.training:
315 315
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
316-
y = torch.empty_like(x, dtype=torch.float16)
316+
y = torch.empty_like(x, dtype=x.dtype)
317 317
for i, expert in enumerate(self.experts):
318 318
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
319 319
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)

trainer/trainer_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,7 @@ def Logger(content):
23 23

24 24

25 25
def get_lr(current_step, total_steps, lr):
26-
min_lr = lr / 10
27-
return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps))
26+
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
28 27

29 28

30 29
def init_distributed_mode():

0 commit comments

Comments
 (0)