Skip to content

Commit 931e5c5

Browse files
update lecun_normal in initializer.py and fix __all__ and API doc (#1187)
1 parent e29fd66 commit 931e5c5

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

docs/zh/api/utils/initializer.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,7 @@
1616
- kaiming_normal_
1717
- linear_init_
1818
- conv_init_
19+
- glorot_normal_
20+
- lecun_normal_
1921
show_root_heading: True
2022
heading_level: 3

ppsci/utils/initializer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
"kaiming_normal_",
4747
"linear_init_",
4848
"conv_init_",
49+
"glorot_normal_",
50+
"lecun_normal_",
4951
]
5052

5153

@@ -496,3 +498,32 @@ def glorot_normal_(tensor: paddle.Tensor) -> paddle.Tensor:
496498
trunc_normal_(tensor)
497499
tensor.set_value(tensor * stddev)
498500
return tensor
501+
502+
503+
def lecun_normal_(tensor: paddle.Tensor) -> paddle.Tensor:
504+
"""Modify tensor inplace using jax-style lecun_normal.
505+
506+
References:
507+
https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L480-L513
508+
509+
Args:
510+
tensor (paddle.Tensor): Paddle Tensor/Parameter.
511+
512+
Returns:
513+
paddle.Tensor: Initialized tensor.
514+
515+
Examples:
516+
>>> import paddle
517+
>>> import ppsci
518+
>>> param = paddle.empty((128, 256), "float32")
519+
>>> param = ppsci.utils.initializer.lecun_normal_(param)
520+
"""
521+
assert (
522+
tensor.ndim == 2
523+
), f"lecun_normal_ only support 2D tensor now, but got ndim={tensor.ndim}"
524+
fin, _ = tensor.shape
525+
var = 1.0 / fin
526+
stddev = math.sqrt(var) / 0.87962566103423978
527+
trunc_normal_(tensor)
528+
tensor.set_value(tensor * stddev)
529+
return tensor

0 commit comments

Comments
 (0)