|
2 | 2 | import ast |
3 | 3 | import math |
4 | 4 | import numpy as np |
| 5 | +from scipy.stats import truncnorm |
5 | 6 | import paddle |
6 | 7 |
|
7 | 8 | kLiteralTensorSize = 64 |
@@ -197,17 +198,21 @@ def init_integer_tensor(dtype, shape, min_val, max_val, use_numpy): |
197 | 198 | def init_float_tensor(shape, mean, std, min_val, max_val, use_numpy): |
198 | 199 | tensor = None |
199 | 200 | if use_numpy: |
200 | | - if mean is not None and std is not None: |
| 201 | + if mean is not None and std is not None and std != 0.0: |
201 | 202 | # NumPy does not support truncated normal, we simulate it here. |
202 | | - array = np.random.normal(0, 1, shape) * std * 0.2 + mean |
203 | | - array = np.clip(array, min_val, max_val) |
| 203 | + a = (min_val - mean) / std |
| 204 | + b = (max_val - mean) / std |
| 205 | + array = truncnorm.rvs(a, b, loc=mean, scale=std, size=shape) |
204 | 206 | else: |
205 | 207 | array = np.random.uniform(low=min_val, high=max_val, size=shape) |
206 | 208 | tensor = paddle.to_tensor(array) |
207 | 209 | else: |
208 | 210 | if mean is not None and std is not None: |
209 | | - tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean |
210 | | - tensor = paddle.clip(tensor, min=min_val, max=max_val) |
| 211 | + tensor = paddle.empty(shape=shape, dtype="float32") |
| 212 | + initializer = paddle.nn.initializer.TruncatedNormal( |
| 213 | + mean=mean, std=std, a=min_val, b=max_val |
| 214 | + ) |
| 215 | + initializer(tensor) |
211 | 216 | else: |
212 | 217 | tensor = paddle.uniform( |
213 | 218 | shape=shape, dtype="float32", min=min_val, max=max_val |
|
0 commit comments