Skip to content

Commit db3b943

Browse files
authored
Adding Normal distribution initializer and unit tests for python initializers (#5256)
1 parent 0b76c73 commit db3b943

File tree

4 files changed

+177
-8
lines changed

4 files changed

+177
-8
lines changed

paddle/operators/gaussian_random_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
4545
// Infers the output shape of GaussianRandomOp from its "shape" attribute.
// Enforces that Output("Out") exists and that "shape" is non-empty BEFORE
// any work is done with it (the original checked emptiness only after
// copying the attribute into the temp vector).
void InferShape(framework::InferShapeContext* ctx) const override {
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of GaussianRandomOp should not be null.");
  auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
  // Validate up front: an empty shape attribute is a user error.
  PADDLE_ENFORCE(shape.size() > 0UL,
                 "shape can be one int or array. shape must be set.");
  // Widen each int dim to int64_t for make_ddim via the range constructor
  // instead of a manual reserve/push_back loop.
  std::vector<int64_t> temp(shape.begin(), shape.end());
  ctx->SetOutputDim("Out", framework::make_ddim(temp));
}
5858

@@ -74,7 +74,7 @@ GaussianRandom operator.
7474
Use to initialize tensor with gaussian random generator.
7575
)DOC");
7676

77-
AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
77+
AddAttr<std::vector<int>>("shape", "The dimension of random tensor.");
7878
AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
7979
AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
8080
AddAttr<int>("seed",

python/paddle/v2/framework/initializer.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __call__(self, var, block):
6262

6363

6464
class UniformInitializer(Initializer):
65-
"""Implements for random uniform distribution initializer
65+
"""Implements the random uniform distribution initializer
6666
"""
6767

6868
def __init__(self, low=-1.0, high=1.0, seed=0):
@@ -75,6 +75,7 @@ def __init__(self, low=-1.0, high=1.0, seed=0):
7575
"""
7676
assert low is not None
7777
assert high is not None
78+
assert high >= low
7879
assert seed is not None
7980
super(UniformInitializer, self).__init__()
8081
self._low = low
@@ -107,3 +108,51 @@ def __call__(self, var, block):
107108
})
108109
var.op = op
109110
return op
111+
112+
113+
class NormalInitializer(Initializer):
    """Implements the random Normal(Gaussian) distribution initializer
    """

    def __init__(self, loc=0.0, scale=1.0, seed=0):
        """Constructor for NormalInitializer

        Args:
            loc: mean of the normal distribution
            scale: standard deviation of the normal distribution
            seed: random seed
        """
        assert loc is not None
        assert scale is not None
        assert seed is not None
        super(NormalInitializer, self).__init__()
        self._mean = loc
        self._std_dev = scale
        self._seed = seed

    def __call__(self, var, block):
        """Add normal distribution initialization ops for a variable

        Args:
            var: Variable that needs to be initialized
            block: The block in which initialization ops
                   should be added

        Returns:
            the initialization op
        """
        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
        # Gather the gaussian_random attributes in one place for readability.
        gaussian_attrs = {
            "shape": var.shape,
            "data_type": int(var.data_type),
            "mean": self._mean,
            "std": self._std_dev,
            "seed": self._seed
        }
        # Initialization Ops should be prepended and not appended
        op = block.prepend_op(
            type="gaussian_random", outputs={"Out": var}, attrs=gaussian_attrs)
        var.op = op
        return op

python/paddle/v2/framework/tests/test_gaussian_random_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def gaussian_random_test(self, place):
1919
op = Operator(
2020
"gaussian_random",
2121
Out='Out',
22-
dims=[1000, 784],
22+
shape=[1000, 784],
2323
mean=.0,
2424
std=1.,
2525
seed=10)
python/paddle/v2/framework/tests/test_initializer.py (new file — name inferred from the commit title and module imports; verify against the original commit)

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import unittest
2+
3+
import paddle.v2.framework.framework as framework
4+
import paddle.v2.framework.initializer as initializer
5+
6+
# Absolute tolerance for the assertAlmostEqual checks in this module.
DELTA = 1e-5
7+
8+
9+
class TestConstantInitializer(unittest.TestCase):
    def test_constant_initializer_default_value(self):
        """A default ConstantInitializer must emit fill_constant with value 0.0
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.ConstantInitializer())
        # Exactly one init op should have been prepended to the block.
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'fill_constant')
        self.assertAlmostEqual(op.attr('value'), 0.0, delta=DELTA)

    def test_constant_initializer(self):
        """A ConstantInitializer must pass its supplied value through to the op
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.ConstantInitializer(2.3))
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'fill_constant')
        self.assertAlmostEqual(op.attr('value'), 2.3, delta=DELTA)
42+
43+
class TestUniformInitializer(unittest.TestCase):
    def test_uniform_initializer_default_value(self):
        """A default UniformInitializer must emit uniform_random over [-1, 1], seed 0
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.UniformInitializer())
        # Exactly one init op should have been prepended to the block.
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'uniform_random')
        self.assertAlmostEqual(op.attr('min'), -1.0, delta=DELTA)
        self.assertAlmostEqual(op.attr('max'), 1.0, delta=DELTA)
        self.assertEqual(op.attr('seed'), 0)

    def test_uniform_initializer(self):
        """A UniformInitializer must pass min/max/seed through to the op
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.UniformInitializer(-4.2, 3.1, 123))
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'uniform_random')
        self.assertAlmostEqual(op.attr('min'), -4.2, delta=DELTA)
        self.assertAlmostEqual(op.attr('max'), 3.1, delta=DELTA)
        self.assertEqual(op.attr('seed'), 123)
80+
81+
class TestNormalInitializer(unittest.TestCase):
    def test_normal_initializer_default_value(self):
        """A default NormalInitializer must emit gaussian_random with mean 0, std 1, seed 0
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.NormalInitializer())
        # Exactly one init op should have been prepended to the block.
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'gaussian_random')
        self.assertAlmostEqual(op.attr('mean'), 0.0, delta=DELTA)
        self.assertAlmostEqual(op.attr('std'), 1.0, delta=DELTA)
        self.assertEqual(op.attr('seed'), 0)

    def test_normal_initializer(self):
        """A NormalInitializer must pass mean/std/seed through to the op
        """
        prog = framework.Program()
        main_block = prog.global_block()
        main_block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="param",
            initializer=initializer.NormalInitializer(2.3, 1.9, 123))
        self.assertEqual(len(main_block.ops), 1)
        op = main_block.ops[0]
        self.assertEqual(op.type, 'gaussian_random')
        self.assertAlmostEqual(op.attr('mean'), 2.3, delta=DELTA)
        self.assertAlmostEqual(op.attr('std'), 1.9, delta=DELTA)
        self.assertEqual(op.attr('seed'), 123)
118+
119+
# Allow the test module to be executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)