|
21 | 21 | from mindspore import Parameter, Tensor, mint, nn, ops |
22 | 22 | from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
23 | 23 |
|
| 24 | +from mindone.models.utils import normal_, zeros_ |
24 | 25 | from mindone.transformers.cache_utils import Cache, get_max_length, get_seq_length, update |
25 | 26 | from mindone.transformers.generation import GenerationMixin |
26 | 27 | from mindone.transformers.mindspore_adapter import str_to_dtype |
@@ -508,16 +509,15 @@ class Qwen2PreTrainedModel(MSPreTrainedModel): |
508 | 509 | _supports_attention_backend = True |
509 | 510 |
|
510 | 511 | def _init_weights(self, module): |
511 | | - # std = self.config.initializer_range |
512 | | - # if isinstance(module, nn.Dense): |
513 | | - # module.weight.data.normal_(mean=0.0, std=std) |
514 | | - # if module.bias is not None: |
515 | | - # module.bias.data.zero_() |
516 | | - # elif isinstance(module, nn.Embedding): |
517 | | - # module.weight.data.normal_(mean=0.0, std=std) |
518 | | - # if module.padding_idx is not None: |
519 | | - # module.weight.data[module.padding_idx].zero_() |
520 | | - pass |
| 512 | + std = self.config.initializer_range |
| 513 | + if isinstance(module, nn.Dense): |
| 514 | + normal_(module.weight, mean=0.0, std=std) |
| 515 | + if module.bias is not None: |
| 516 | + zeros_(module.bias) |
| 517 | + elif isinstance(module, nn.Embedding): |
| 518 | + normal_(module.embedding_table, mean=0.0, std=std) |
| 519 | + if module.padding_idx is not None: |
| 520 | + module.embedding_table[module.padding_idx] = 0 |
521 | 521 |
|
522 | 522 |
|
523 | 523 | QWEN2_INPUTS_DOCSTRING = r""" |
|
0 commit comments