Skip to content

Commit cf12823

Browse files
Add truncated gaussian initializer. (#13000)
* Add truncated gaussian initializer. * Fix unitest. * Update API.spec * Fix code style and fix bug. * Fix code style. * Small fix.
1 parent 642cf6c commit cf12823

File tree

6 files changed

+471
-4
lines changed

6 files changed

+471
-4
lines changed

doc/fluid/api/initializer.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ Normal
3232
:members:
3333
:noindex:
3434

35+
.. _api_fluid_initializer_TruncatedNormal:
36+
37+
TruncatedNormal
38+
---------------
39+
40+
.. autoclass:: paddle.fluid.initializer.TruncatedNormal
41+
:members:
42+
:noindex:
43+
3544
.. _api_fluid_initializer_Xavier:
3645

3746
Xavier

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program
7979
paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
8080
paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
8181
paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0))
82+
paddle.fluid.initializer.TruncatedNormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0))
8283
paddle.fluid.initializer.XavierInitializer.__init__ ArgSpec(args=['self', 'uniform', 'fan_in', 'fan_out', 'seed'], varargs=None, keywords=None, defaults=(True, None, None, 0))
8384
paddle.fluid.initializer.BilinearInitializer.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
8485
paddle.fluid.initializer.MSRAInitializer.__init__ ArgSpec(args=['self', 'uniform', 'fan_in', 'seed'], varargs=None, keywords=None, defaults=(True, None, 0))
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <cmath>
#include <limits>
#include <random>

#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
// Inverse error function: returns y such that erf(y) == x.
// Uses rational (polynomial-ratio) approximations over three ranges:
// a central range |x| <= 0.85 and two tail transforms. Returns NaN for
// arguments outside [-1, 1] and +/-infinity at the endpoints.
template <typename T>
T Erfinv(T x) {
  if (x < -1 || x > 1) {
    return std::numeric_limits<T>::quiet_NaN();
  } else if (x == 1.0) {
    return std::numeric_limits<T>::infinity();
  } else if (x == -1.0) {
    return -std::numeric_limits<T>::infinity();
  }

  // ln(2); folded into the tail change of variable r = sqrt(ln2 - ln(1-|x|)).
  const T LN2 = 6.931471805599453094172321214581e-1;

  // Numerator (A*) and denominator (B*) coefficients for the central range.
  const T A0 = 1.1975323115670912564578e0;
  const T A1 = 4.7072688112383978012285e1;
  const T A2 = 6.9706266534389598238465e2;
  const T A3 = 4.8548868893843886794648e3;
  const T A4 = 1.6235862515167575384252e4;
  const T A5 = 2.3782041382114385731252e4;
  const T A6 = 1.1819493347062294404278e4;
  const T A7 = 8.8709406962545514830200e2;

  const T B0 = 1.0000000000000000000e0;
  const T B1 = 4.2313330701600911252e1;
  const T B2 = 6.8718700749205790830e2;
  const T B3 = 5.3941960214247511077e3;
  const T B4 = 2.1213794301586595867e4;
  const T B5 = 3.9307895800092710610e4;
  const T B6 = 2.8729085735721942674e4;
  const T B7 = 5.2264952788528545610e3;

  // Numerator (C*) and denominator (D*) coefficients for the near tail.
  const T C0 = 1.42343711074968357734e0;
  const T C1 = 4.63033784615654529590e0;
  const T C2 = 5.76949722146069140550e0;
  const T C3 = 3.64784832476320460504e0;
  const T C4 = 1.27045825245236838258e0;
  const T C5 = 2.41780725177450611770e-1;
  const T C6 = 2.27238449892691845833e-2;
  const T C7 = 7.74545014278341407640e-4;

  const T D0 = 1.4142135623730950488016887e0;
  const T D1 = 2.9036514445419946173133295e0;
  const T D2 = 2.3707661626024532365971225e0;
  const T D3 = 9.7547832001787427186894837e-1;
  const T D4 = 2.0945065210512749128288442e-1;
  const T D5 = 2.1494160384252876777097297e-2;
  const T D6 = 7.7441459065157709165577218e-4;
  const T D7 = 1.4859850019840355905497876e-9;

  // Numerator (E*) and denominator (F*) coefficients for the far tail.
  const T E0 = 6.65790464350110377720e0;
  const T E1 = 5.46378491116411436990e0;
  const T E2 = 1.78482653991729133580e0;
  const T E3 = 2.96560571828504891230e-1;
  const T E4 = 2.65321895265761230930e-2;
  const T E5 = 1.24266094738807843860e-3;
  const T E6 = 2.71155556874348757815e-5;
  const T E7 = 2.01033439929228813265e-7;

  const T F0 = 1.414213562373095048801689e0;
  const T F1 = 8.482908416595164588112026e-1;
  const T F2 = 1.936480946950659106176712e-1;
  const T F3 = 2.103693768272068968719679e-2;
  const T F4 = 1.112800997078859844711555e-3;
  const T F5 = 2.611088405080593625138020e-5;
  const T F6 = 2.010321207683943062279931e-7;
  const T F7 = 2.891024605872965461538222e-15;

  // BUG FIX: the original called the C integer `abs`, which truncates a
  // floating-point argument to int. That made abs_x == 0 for every |x| < 1,
  // so the tail branches below were unreachable and results for
  // |x| > 0.85 were computed with the wrong (central) approximation.
  // std::abs has floating-point overloads and keeps the fraction.
  T abs_x = std::abs(x);

  if (abs_x <= 0.85) {
    // Central range: evaluate both polynomials in r via Horner's scheme.
    T r = 0.180625 - 0.25 * x * x;
    T num =
        (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) *
             r +
         A0);
    T den =
        (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) *
             r +
         B0);
    return x * num / den;
  }

  // Tail: change of variable; r grows as |x| approaches 1.
  T r = std::sqrt(LN2 - std::log(1.0 - abs_x));

  T num, den;
  if (r <= 5.0) {
    r = r - 1.6;
    num =
        (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) *
             r +
         C0);
    den =
        (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) *
             r +
         D0);
  } else {
    r = r - 5.0;
    num =
        (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) *
             r +
         E0);
    den =
        (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) *
             r +
         F0);
  }

  // The tail approximation is computed for |x|; restore the sign of x.
  if (x < 0) {
    return -num / den;
  } else {
    return num / den;
  }
}
135+
136+
// Inverse-CDF transform: maps a uniform sample in (0, 1) to a sample from a
// gaussian with the given mean/std, truncated at two standard deviations
// (i.e. to [mean - 2*std, mean + 2*std]).
template <typename T>
struct TruncatedNormal {
  T mean, std;
  T a_normal_cdf;  // standard-normal CDF at the lower cut, Phi(-2)
  T b_normal_cdf;  // standard-normal CDF at the upper cut, Phi(+2)
  TruncatedNormal(T mean, T std) : mean(mean), std(std) {
    // CDF of the standard normal: Phi(x) = (1 + erf(x / sqrt(2))) / 2.
    auto normal_cdf = [](T x) {
      return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
    };
    a_normal_cdf = normal_cdf(-2.0);
    b_normal_cdf = normal_cdf(2.0);
  }

  T operator()(T value) const {
    // Rescale the uniform draw into the CDF mass of the truncated region,
    // then invert: z = sqrt(2) * erfinv(2p - 1) is a standard-normal sample
    // restricted to [-2, 2].
    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
    // BUG FIX: scale by std first, then shift by mean. The original computed
    // (z + mean) * std, which multiplies the mean by std and centers the
    // output distribution at mean * std instead of mean.
    return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
  }
};
154+
155+
template <typename T>
156+
class CPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
157+
public:
158+
void Compute(const framework::ExecutionContext& context) const override {
159+
float mean = context.Attr<float>("mean");
160+
float std = context.Attr<float>("std");
161+
auto* tensor = context.Output<framework::Tensor>("Out");
162+
T* data = tensor->mutable_data<T>(context.GetPlace());
163+
164+
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
165+
std::minstd_rand engine;
166+
if (seed == 0) {
167+
seed = std::random_device()();
168+
}
169+
engine.seed(seed);
170+
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
171+
1.0);
172+
TruncatedNormal<T> truncated_normal(mean, std);
173+
int64_t size = tensor->numel();
174+
for (int64_t i = 0; i < size; ++i) {
175+
data[i] = truncated_normal(dist(engine));
176+
}
177+
}
178+
};
179+
180+
class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
181+
public:
182+
using framework::OperatorWithKernel::OperatorWithKernel;
183+
184+
void InferShape(framework::InferShapeContext* ctx) const override {
185+
PADDLE_ENFORCE(
186+
ctx->HasOutput("Out"),
187+
"Output(Out) of TruncatedGaussianRandomOp should not be null.");
188+
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
189+
std::vector<int64_t> out_dim;
190+
out_dim.reserve(shape.size());
191+
for (auto dim : shape) {
192+
out_dim.push_back(static_cast<int64_t>(dim));
193+
}
194+
PADDLE_ENFORCE(shape.size() > 0UL,
195+
"shape can be one int or array. shape must be set.");
196+
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
197+
}
198+
199+
protected:
200+
framework::OpKernelType GetExpectedKernelType(
201+
const framework::ExecutionContext& ctx) const override {
202+
framework::LibraryType library{framework::LibraryType::kPlain};
203+
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
204+
return framework::OpKernelType(
205+
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
206+
ctx.device_context(), layout, library);
207+
}
208+
};
209+
210+
// Declares the op's output and its attributes (shape, mean, std, seed,
// dtype) together with the operator documentation shown to users.
class TruncatedGaussianRandomOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddOutput("Out", "Output tensor of truncated gaussian random op.");

    // Desired dimensions of the output tensor; enforced non-empty by
    // TruncatedGaussianRandomOp::InferShape.
    AddAttr<std::vector<int>>("shape",
                              "(vector<int>) "
                              "The dimension of random tensor.");
    AddAttr<float>("mean",
                   "(float, default 0.0) "
                   "mean of random tensor.")
        .SetDefault(.0f);
    AddAttr<float>("std",
                   "(float, default 1.0) "
                   "std of random tensor.")
        .SetDefault(1.0f);
    // seed == 0 makes the kernels draw a fresh seed from std::random_device
    // on every run; any non-zero seed makes the op fully deterministic.
    AddAttr<int>("seed",
                 "(int, default 0) "
                 "Random seed of generator."
                 "0 means use system wide seed."
                 "Note that if seed is not 0, this operator will always "
                 "generate the same random numbers every time.")
        .SetDefault(0);
    // Default output element type is framework::proto::VarType::FP32
    // (described as "5" in the attribute string).
    AddAttr<int>("dtype",
                 "(int, default 5(FP32)) "
                 "Output data type.")
        .SetDefault(framework::proto::VarType::FP32);
    AddComment(R"DOC(
TruncatedGaussianRandom Operator.

Used to initialize tensors with truncated gaussian random generator.

)DOC");
  }
};
246+
247+
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
// Pure initializer op: it only produces random values, so no gradient op is
// registered.
REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random,
                             ops::TruncatedGaussianRandomOp,
                             ops::TruncatedGaussianRandomOpMaker);
// Only a float CPU kernel is registered here.
REGISTER_OP_CPU_KERNEL(truncated_gaussian_random,
                       ops::CPUTruncatedGaussianRandomKernel<float>);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <thrust/random.h>
16+
#include <thrust/transform.h>
17+
#include <limits>
18+
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/framework/operator.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
template <typename T>
25+
struct TruncatedNormal {
26+
T mean, std;
27+
T a_normal_cdf;
28+
T b_normal_cdf;
29+
unsigned int seed;
30+
T numeric_min;
31+
32+
__host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed)
33+
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
34+
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
35+
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
36+
}
37+
38+
__host__ __device__ T operator()(const unsigned int n) const {
39+
thrust::minstd_rand rng;
40+
rng.seed(seed);
41+
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
42+
rng.discard(n);
43+
T value = dist(rng);
44+
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
45+
return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std;
46+
}
47+
};
48+
49+
template <typename T>
50+
class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
51+
public:
52+
void Compute(const framework::ExecutionContext& context) const override {
53+
auto* tensor = context.Output<framework::Tensor>("Out");
54+
T* data = tensor->mutable_data<T>(context.GetPlace());
55+
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
56+
if (seed == 0) {
57+
std::random_device rd;
58+
seed = rd();
59+
}
60+
T mean = static_cast<T>(context.Attr<float>("mean"));
61+
T std = static_cast<T>(context.Attr<float>("std"));
62+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
63+
int64_t size = tensor->numel();
64+
thrust::transform(
65+
index_sequence_begin, index_sequence_begin + size,
66+
thrust::device_ptr<T>(data),
67+
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
68+
}
69+
};
70+
71+
} // namespace operators
} // namespace paddle

// Only a float CUDA kernel is registered here.
REGISTER_OP_CUDA_KERNEL(
    truncated_gaussian_random,
    paddle::operators::GPUTruncatedGaussianRandomKernel<float>);

0 commit comments

Comments
 (0)