Skip to content

Commit fbdf76a

Browse files
add '_' to initializer.glorot_normal
1 parent f654d80 commit fbdf76a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ppsci/utils/initializer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"uniform_",
3737
"normal_",
3838
"trunc_normal_",
39-
"glorot_normal",
39+
"glorot_normal_",
4040
"constant_",
4141
"ones_",
4242
"zeros_",
@@ -472,7 +472,7 @@ def conv_init_(module: nn.Layer) -> None:
472472
uniform_(module.bias, -bound, bound)
473473

474474

475-
def glorot_normal(tensor: paddle.Tensor) -> paddle.Tensor:
475+
def glorot_normal_(tensor: paddle.Tensor) -> paddle.Tensor:
476476
"""Modify tensor inplace using jax-style glorot_normal.
477477
478478
Args:
@@ -485,11 +485,11 @@ def glorot_normal(tensor: paddle.Tensor) -> paddle.Tensor:
485485
>>> import paddle
486486
>>> import ppsci
487487
>>> param = paddle.empty((128, 256), "float32")
488-
>>> param = ppsci.utils.initializer.glorot_normal(param)
488+
>>> param = ppsci.utils.initializer.glorot_normal_(param)
489489
"""
490490
assert (
491491
tensor.ndim == 2
492-
), f"glorot_normal only support 2D tensor now, but got ndim={tensor.ndim}"
492+
), f"glorot_normal_ only support 2D tensor now, but got ndim={tensor.ndim}"
493493
fin, fout = tensor.shape
494494
var = 2.0 / (fin + fout)
495495
stddev = math.sqrt(var) * 0.87962566103423978

0 commit comments

Comments
 (0)