Skip to content

Commit 7e651c8

Browse files
Fix truncated norm (#13785)
* Fix truncated normal. * test=develop
1 parent 16b1beb commit 7e651c8

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

paddle/fluid/operators/truncated_gaussian_random_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ struct TruncatedNormal {
148148

149149
T operator()(T value) const {
150150
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
151-
return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std;
151+
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
152152
}
153153
};
154154

paddle/fluid/operators/truncated_gaussian_random_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct TruncatedNormal {
4242
rng.discard(n);
4343
T value = dist(rng);
4444
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
45-
return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std;
45+
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
4646
}
4747
};
4848

@@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
5252
void Compute(const framework::ExecutionContext& context) const override {
5353
auto* tensor = context.Output<framework::Tensor>("Out");
5454
T* data = tensor->mutable_data<T>(context.GetPlace());
55+
5556
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
5657
if (seed == 0) {
5758
std::random_device rd;

0 commit comments

Comments
 (0)