
Commit 61eafbe

Adding a framework for variable initializers (#5232)

1 parent 9b70b6a

5 files changed: +128 −55 lines changed

python/paddle/v2/framework/framework.py

Lines changed: 4 additions & 15 deletions
@@ -354,8 +354,8 @@ def all_parameters(self):
 
     def create_var(self, *args, **kwargs):
         var = Variable(self, *args, **kwargs)
-        if 'init_attr' in kwargs:
-            self._prepend_initialize_ops_(var, kwargs['init_attr'])
+        if 'initializer' in kwargs:
+            kwargs['initializer'](var, self)
         return var
 
     def has_var(self, name):
@@ -364,8 +364,8 @@ def has_var(self, name):
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
         param = Parameter(global_block, *args, **kwargs)
-        if 'init_attr' in kwargs:
-            self._prepend_initialize_ops_(param, kwargs['init_attr'])
+        if 'initializer' in kwargs:
+            kwargs['initializer'](param, self)
         return param
 
     def append_op(self, *args, **kwargs):
@@ -424,17 +424,6 @@ def sync_with_cpp(self):
         for index in range(len(self.ops)):
             assert self.ops[index].desc == ops_in_cpp[index]
 
-    def _prepend_initialize_ops_(self, param, init_attr):
-        op_type = init_attr['type']
-        init_attr['shape'] = param.shape
-        init_attr['data_type'] = int(param.data_type)
-        op = self.prepend_op(
-            type=op_type,
-            inputs=None,
-            outputs={'Out': [param]},
-            attrs=init_attr)
-        param.op = op
-
 
 class Program(object):
     def __init__(self):
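
With this change a Block no longer builds init ops itself; it invokes whatever callable is passed as the `initializer` keyword. A minimal usage sketch of the new convention (the parameter name, shape, and dtype below are illustrative, and it assumes `create_parameter` forwards the usual Variable keywords):

    from paddle.v2.framework.framework import g_program
    from paddle.v2.framework.initializer import UniformInitializer

    # Illustrative only: the initializer callable prepends the matching
    # init op to the block that creates the parameter.
    block = g_program.global_block()
    w = block.create_parameter(
        name="fc.w",
        shape=[784, 256],
        dtype="float32",
        initializer=UniformInitializer(low=-1.0, high=1.0))
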
python/paddle/v2/framework/initializer.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import paddle.v2.framework.framework as framework
+
+__all__ = ['ConstantInitializer', 'UniformInitializer']
+
+
+class Initializer(object):
+    """Base class for variable initializers
+
+    Defines the common interface of variable initializers.
+    They add operations to the init program that are used
+    to initialize variables. Users should not use this class
+    directly, but need to use one of its implementations.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, param, block):
+        """Add corresponding initialization operations to the network
+        """
+        raise NotImplementedError()
+
+
+class ConstantInitializer(Initializer):
+    """Implements the constant initializer
+    """
+
+    def __init__(self, value=0.0):
+        """Constructor for ConstantInitializer
+
+        Args:
+            value: constant value to initialize the variable
+        """
+        assert value is not None
+        super(ConstantInitializer, self).__init__()
+        self._value = value
+
+    def __call__(self, var, block):
+        """Add constant initialization ops for a variable
+
+        Args:
+            var: Variable that needs to be initialized
+            block: The block in which initialization ops
+                   should be added
+
+        Returns:
+            the initialization op
+        """
+        assert isinstance(var, framework.Variable)
+        assert isinstance(block, framework.Block)
+        # Initialization Ops should be prepended and not appended
+        op = block.prepend_op(
+            type="fill_constant",
+            outputs={"Out": var},
+            attrs={
+                "shape": var.shape,
+                "data_type": int(var.data_type),
+                "value": self._value
+            })
+        var.op = op
+        return op
+
+
+class UniformInitializer(Initializer):
+    """Implements the random uniform distribution initializer
+    """
+
+    def __init__(self, low=-1.0, high=1.0, seed=0):
+        """Constructor for UniformInitializer
+
+        Args:
+            low: lower boundary of the uniform distribution
+            high: upper boundary of the uniform distribution
+            seed: random seed
+        """
+        assert low is not None
+        assert high is not None
+        assert seed is not None
+        super(UniformInitializer, self).__init__()
+        self._low = low
+        self._high = high
+        self._seed = seed
+
+    def __call__(self, var, block):
+        """Add uniform distribution initialization ops for a variable
+
+        Args:
+            var: Variable that needs to be initialized
+            block: The block in which initialization ops
+                   should be added
+
+        Returns:
+            the initialization op
+        """
+        assert isinstance(var, framework.Variable)
+        assert isinstance(block, framework.Block)
+        # Initialization Ops should be prepended and not appended
+        op = block.prepend_op(
+            type="uniform_random",
+            outputs={"Out": var},
+            attrs={
+                "shape": var.shape,
+                "data_type": int(var.data_type),
+                "min": self._low,
+                "max": self._high,
+                "seed": self._seed
+            })
+        var.op = op
+        return op
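
New schemes can be added by subclassing Initializer and following the same prepend pattern. A hypothetical NormalInitializer, sketched for illustration only; it is not part of this commit and assumes a `gaussian_random` op with `mean`/`std` attributes exists:

    import paddle.v2.framework.framework as framework
    from paddle.v2.framework.initializer import Initializer


    class NormalInitializer(Initializer):
        """Hypothetical Gaussian initializer, shown only to illustrate
        the extension pattern; assumes a `gaussian_random` op with
        `mean`/`std` attributes (not part of this commit).
        """

        def __init__(self, mean=0.0, std=1.0, seed=0):
            super(NormalInitializer, self).__init__()
            self._mean = mean
            self._std = std
            self._seed = seed

        def __call__(self, var, block):
            assert isinstance(var, framework.Variable)
            assert isinstance(block, framework.Block)
            # Prepend so the fill runs before any op that reads the variable.
            op = block.prepend_op(
                type="gaussian_random",
                outputs={"Out": var},
                attrs={
                    "shape": var.shape,
                    "data_type": int(var.data_type),
                    "mean": self._mean,
                    "std": self._std,
                    "seed": self._seed
                })
            var.op = op
            return op
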

python/paddle/v2/framework/layer_helper.py

Lines changed: 4 additions & 15 deletions
@@ -5,6 +5,8 @@
 
 from paddle.v2.framework.framework import Variable, g_program, \
     g_init_program
+from paddle.v2.framework.initializer import ConstantInitializer, \
+    UniformInitializer
 
 
 def unique_name(prefix):
@@ -66,14 +68,7 @@ def input(self, input_param_name='input'):
 
     @property
     def param_attr(self):
-        default = {
-            'name': None,
-            'init_attr': {
-                'type': 'uniform_random',
-                'min': -1.0,
-                'max': 1.0
-            }
-        }
+        default = {'name': None, 'initializer': UniformInitializer()}
         actual = self.kwargs.get('param_attr', None)
         if actual is None:
             actual = default
@@ -83,13 +78,7 @@ def param_attr(self):
         return actual
 
     def bias_attr(self):
-        default = {
-            'name': None,
-            'init_attr': {
-                'type': 'fill_constant',
-                'value': 0.0
-            }
-        }
+        default = {'name': None, 'initializer': ConstantInitializer()}
         bias_attr = self.kwargs.get('bias_attr', None)
         if bias_attr is True:
            bias_attr = default
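
Because the defaults are plain dict entries, callers can swap in a different initializer per layer. An illustrative call, assuming the `fc` layer from this codebase and an existing input variable named `image` (neither shown in this diff):

    import paddle.v2.framework.layers as layers
    from paddle.v2.framework.initializer import UniformInitializer

    # Override the default UniformInitializer() with a narrower range;
    # `image` is assumed to be an existing input variable.
    hidden = layers.fc(input=image,
                       size=128,
                       param_attr={
                           'name': None,
                           'initializer': UniformInitializer(
                               low=-0.1, high=0.1)
                       })
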

python/paddle/v2/framework/layers.py

Lines changed: 7 additions & 19 deletions
@@ -1,6 +1,7 @@
 from paddle.v2.framework.layer_helper import LayerHelper, unique_name
 import paddle.v2.framework.core as core
 from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
+from paddle.v2.framework.initializer import ConstantInitializer
 import re
 
 __all__ = [
@@ -440,26 +441,12 @@ def batch_norm(input,
     else:
         raise ValueError("unsupported data layout:" + data_layout)
 
-    def get_init_attr(value):
-        if not isinstance(value, float):
-            raise ValueError("attr value should be a float")
-        return {'type': 'fill_constant', 'value': value}
-
-    def prepend_init_op(var, init_attr):
-        assert isinstance(var, Variable)
-        op_type = init_attr['type']
-        init_attr['shape'] = var.shape
-        init_attr['data_type'] = int(var.data_type)
-        op = var.block.prepend_op(
-            type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr)
-        return op
-
-    def create_persistable_var(dtype, shape, init_attr=None):
+    def create_persistable_var(dtype, shape, initializer=None):
         name = unique_name(".".join([helper.name, "xxxx"]))
         var = init_program.global_block().create_var(
             dtype=dtype, shape=shape, name=name, persistable=True)
-        if 'init_attr' is not None:
-            prepend_init_op(var, init_attr)
+        if initializer is not None:
+            initializer(var, var.block)
         return program.global_block().create_var(
             name=name, dtype=dtype, shape=shape, persistable=True)
 
@@ -472,8 +459,9 @@ def create_persistable_var(dtype, shape, init_attr=None):
         attr=helper.param_attr, shape=param_shape, dtype=dtype)
 
     # create input
-    mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0))
-    variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0))
+    mean = create_persistable_var(dtype, param_shape, ConstantInitializer(0.0))
+    variance = create_persistable_var(dtype, param_shape,
+                                      ConstantInitializer(1.0))
 
     # create output
     # mean and mean_out share the same memory
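
The same two-step pattern works outside batch_norm: create the persistable variable in the init program's global block, run the initializer on it, then mirror it in the main program. A sketch, assuming the `g_init_program`/`g_program` globals and illustrative names and shapes:

    from paddle.v2.framework.framework import g_init_program, g_program
    from paddle.v2.framework.initializer import ConstantInitializer

    # The init op lands in the init program; the main program only
    # sees a mirror of the persistable variable (names illustrative).
    var = g_init_program.global_block().create_var(
        name="bn.variance", dtype="float32", shape=[64], persistable=True)
    ConstantInitializer(1.0)(var, var.block)
    mirror = g_program.global_block().create_var(
        name="bn.variance", dtype="float32", shape=[64], persistable=True)
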

python/paddle/v2/framework/tests/test_recognize_digits_mlp.py

Lines changed: 4 additions & 6 deletions
@@ -3,9 +3,10 @@
 import paddle.v2.framework.core as core
 import paddle.v2.framework.optimizer as optimizer
 
-from paddle.v2.framework.framework import Program, g_program
+from paddle.v2.framework.framework import Program
 from paddle.v2.framework.executor import Executor
 from paddle.v2.framework.regularizer import L2DecayRegularizer
+from paddle.v2.framework.initializer import UniformInitializer
 
 import numpy as np
 
@@ -21,11 +22,8 @@
 
 param_attr = {
     'name': None,
-    'init_attr': {
-        'type': 'uniform_random',
-        'min': -1.0,
-        'max': 1.0
-    },
+    'initializer': UniformInitializer(
+        low=-1.0, high=1.0),
     'regularization': L2DecayRegularizer(0.0005 * BATCH_SIZE)
 }
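
UniformInitializer also accepts a seed, so the same dict can be made reproducible. An illustrative variation on the test's param_attr, assuming the test's imports and an arbitrary seed value:

    param_attr = {
        'name': None,
        # seed is forwarded to the uniform_random op's `seed` attribute
        'initializer': UniformInitializer(
            low=-1.0, high=1.0, seed=42),
        'regularization': L2DecayRegularizer(0.0005 * BATCH_SIZE)
    }
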