Skip to content

Commit bdc8d03

Browse files
committed
Use paddle.nn.initializer.TruncatedNormal to initilize tensor.
1 parent 391240d commit bdc8d03

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

graph_net/paddle/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,10 @@ def replay_tensor(info):
214214
else:
215215
if mean is not None and std is not None:
216216
tensor = paddle.empty(shape=shape, dtype=dtype)
217-
paddle.nn.init.trunc_normal_(
218-
tensor=tensor, mean=mean, std=std, a=min_val, b=max_val
217+
initializer = paddle.nn.initializer.TruncatedNormal(
218+
mean=mean, std=std, a=min_val, b=max_val
219219
)
220+
initializer(tensor)
220221
return tensor.to(dtype).to(device)
221222
else:
222223
return (

0 commit comments

Comments
 (0)