File tree Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Expand file tree Collapse file tree 2 files changed +33
-0
lines changed Original file line number Diff line number Diff line change 16
16
- kaiming_normal_
17
17
- linear_init_
18
18
- conv_init_
19
+ - glorot_normal_
20
+ - lecun_normal_
19
21
show_root_heading: True
20
22
heading_level: 3
Original file line number Diff line number Diff line change 46
46
"kaiming_normal_" ,
47
47
"linear_init_" ,
48
48
"conv_init_" ,
49
+ "glorot_normal_" ,
50
+ "lecun_normal_" ,
49
51
]
50
52
51
53
@@ -496,3 +498,32 @@ def glorot_normal_(tensor: paddle.Tensor) -> paddle.Tensor:
496
498
trunc_normal_ (tensor )
497
499
tensor .set_value (tensor * stddev )
498
500
return tensor
501
+
502
+
503
+ def lecun_normal_ (tensor : paddle .Tensor ) -> paddle .Tensor :
504
+ """Modify tensor inplace using jax-style lecun_normal.
505
+
506
+ References:
507
+ https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L480-L513
508
+
509
+ Args:
510
+ tensor (paddle.Tensor): Paddle Tensor/Parameter.
511
+
512
+ Returns:
513
+ paddle.Tensor: Initialized tensor.
514
+
515
+ Examples:
516
+ >>> import paddle
517
+ >>> import ppsci
518
+ >>> param = paddle.empty((128, 256), "float32")
519
+ >>> param = ppsci.utils.initializer.lecun_normal_(param)
520
+ """
521
+ assert (
522
+ tensor .ndim == 2
523
+ ), f"lecun_normal_ only support 2D tensor now, but got ndim={ tensor .ndim } "
524
+ fin , _ = tensor .shape
525
+ var = 1.0 / fin
526
+ stddev = math .sqrt (var ) / 0.87962566103423978
527
+ trunc_normal_ (tensor )
528
+ tensor .set_value (tensor * stddev )
529
+ return tensor
You can’t perform that action at this time.
0 commit comments