File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 36
36
"uniform_" ,
37
37
"normal_" ,
38
38
"trunc_normal_" ,
39
- "glorot_normal " ,
39
+ "glorot_normal_ " ,
40
40
"constant_" ,
41
41
"ones_" ,
42
42
"zeros_" ,
@@ -472,7 +472,7 @@ def conv_init_(module: nn.Layer) -> None:
472
472
uniform_ (module .bias , - bound , bound )
473
473
474
474
475
- def glorot_normal (tensor : paddle .Tensor ) -> paddle .Tensor :
475
+ def glorot_normal_ (tensor : paddle .Tensor ) -> paddle .Tensor :
476
476
"""Modify tensor inplace using jax-style glorot_normal.
477
477
478
478
Args:
@@ -485,11 +485,11 @@ def glorot_normal(tensor: paddle.Tensor) -> paddle.Tensor:
485
485
>>> import paddle
486
486
>>> import ppsci
487
487
>>> param = paddle.empty((128, 256), "float32")
488
- >>> param = ppsci.utils.initializer.glorot_normal (param)
488
+ >>> param = ppsci.utils.initializer.glorot_normal_ (param)
489
489
"""
490
490
assert (
491
491
tensor .ndim == 2
492
- ), f"glorot_normal only support 2D tensor now, but got ndim={ tensor .ndim } "
492
+ ), f"glorot_normal_ only support 2D tensor now, but got ndim={ tensor .ndim } "
493
493
fin , fout = tensor .shape
494
494
var = 2.0 / (fin + fout )
495
495
stddev = math .sqrt (var ) * 0.87962566103423978
You can’t perform that action at this time.
0 commit comments