Skip to content

Commit 17f9d2e

Browse files
authored
fix qwen2 init (#1164)
1 parent f269f56 commit 17f9d2e

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mindone/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
21 21
from mindspore import Parameter, Tensor, mint, nn, ops
22 22
from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23 23

24+
from mindone.models.utils import normal_, zeros_
24 25
from mindone.transformers.cache_utils import Cache, get_max_length, get_seq_length, update
25 26
from mindone.transformers.generation import GenerationMixin
26 27
from mindone.transformers.mindspore_adapter import str_to_dtype
@@ -508,16 +509,15 @@ class Qwen2PreTrainedModel(MSPreTrainedModel):
508 509
_supports_attention_backend = True
509 510

510 511
def _init_weights(self, module):
511-
# std = self.config.initializer_range
512-
# if isinstance(module, nn.Dense):
513-
# module.weight.data.normal_(mean=0.0, std=std)
514-
# if module.bias is not None:
515-
# module.bias.data.zero_()
516-
# elif isinstance(module, nn.Embedding):
517-
# module.weight.data.normal_(mean=0.0, std=std)
518-
# if module.padding_idx is not None:
519-
# module.weight.data[module.padding_idx].zero_()
520-
pass
512+
std = self.config.initializer_range
513+
if isinstance(module, nn.Dense):
514+
normal_(module.weight, mean=0.0, std=std)
515+
if module.bias is not None:
516+
zeros_(module.bias)
517+
elif isinstance(module, nn.Embedding):
518+
normal_(module.embedding_table, mean=0.0, std=std)
519+
if module.padding_idx is not None:
520+
module.embedding_table[module.padding_idx] = 0
521 521

522 522

523 523
QWEN2_INPUTS_DOCSTRING = r"""

0 commit comments

Comments
 (0)