Skip to content

Commit 26a94b3

Browse files
committed
Fix truncated normal.
test=release/1.0.0
1 parent b2e6e5f commit 26a94b3

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

paddle/fluid/operators/truncated_gaussian_random_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ struct TruncatedNormal {
148148

149149
T operator()(T value) const {
150150
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
151-
return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std;
151+
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
152152
}
153153
};
154154

paddle/fluid/operators/truncated_gaussian_random_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct TruncatedNormal {
4242
rng.discard(n);
4343
T value = dist(rng);
4444
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
45-
return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std;
45+
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
4646
}
4747
};
4848

0 commit comments

Comments
 (0)