
Commit d310856

Change the tensor initialization method back to truncated normal.
1 parent: 2b99941

1 file changed: graph_net/paddle/utils.py (10 additions, 5 deletions)
@@ -2,6 +2,7 @@
 import ast
 import math
 import numpy as np
+from scipy.stats import truncnorm
 import paddle
 
 kLiteralTensorSize = 64
@@ -197,17 +198,21 @@ def init_integer_tensor(dtype, shape, min_val, max_val, use_numpy):
 def init_float_tensor(shape, mean, std, min_val, max_val, use_numpy):
     tensor = None
     if use_numpy:
-        if mean is not None and std is not None:
+        if mean is not None and std is not None and std != 0.0:
             # NumPy does not support truncated normal, we simulate it here.
-            array = np.random.normal(0, 1, shape) * std * 0.2 + mean
-            array = np.clip(array, min_val, max_val)
+            a = (min_val - mean) / std
+            b = (max_val - mean) / std
+            array = truncnorm.rvs(a, b, loc=mean, scale=std, size=shape)
         else:
             array = np.random.uniform(low=min_val, high=max_val, size=shape)
         tensor = paddle.to_tensor(array)
     else:
         if mean is not None and std is not None:
-            tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean
-            tensor = paddle.clip(tensor, min=min_val, max=max_val)
+            tensor = paddle.empty(shape=shape, dtype="float32")
+            initializer = paddle.nn.initializer.TruncatedNormal(
+                mean=mean, std=std, a=min_val, b=max_val
+            )
+            initializer(tensor)
         else:
             tensor = paddle.uniform(
                 shape=shape, dtype="float32", min=min_val, max=max_val
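
For context on the new sampling path, here is a minimal sketch (not part of the committed code) of how scipy.stats.truncnorm is parameterized; the values for mean, std, min_val, max_val, and shape below are illustrative only. scipy takes the truncation points a and b in standard-deviation units relative to loc and scale, which is why the NumPy branch converts the bounds, whereas the Paddle branch appears to pass min_val and max_val to paddle.nn.initializer.TruncatedNormal as absolute cutoffs.

```python
# Sketch only: illustrates the truncnorm parameterization used in the diff.
# mean/std/min_val/max_val/shape are example values, not taken from the repository.
from scipy.stats import truncnorm

mean, std = 0.5, 0.1
min_val, max_val = 0.0, 1.0
shape = (4, 64)

# truncnorm expects its bounds in units of standard deviations around loc/scale,
# hence the (bound - mean) / std conversion performed in init_float_tensor.
a = (min_val - mean) / std
b = (max_val - mean) / std
array = truncnorm.rvs(a, b, loc=mean, scale=std, size=shape)

# Every sample falls inside [min_val, max_val] by construction,
# unlike the previous clip-after-sampling approach.
assert array.min() >= min_val and array.max() <= max_val
print(array.mean(), array.std())
```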
