Commit 42b0748: add unittest

Parent: 974183b

File tree: 3 files changed (+93 −3 lines)


python/paddle/v2/fluid/clip.py

Lines changed: 11 additions & 3 deletions
@@ -113,6 +113,7 @@ def create_operators(self, param, grad):
 
 class GradientClipByGlobalNorm(BaseGradientClipAttr):
     global_norm_var = None
+    local_norm_var = None
     clip_norm_var = None
     scale_var = None
 
@@ -123,12 +124,18 @@ def init(cls, clip_norm):
 
         cls.global_norm_var = layers.fill_constant(
             shape=[1], dtype="float32", value=0.0)
+        cls.local_norm_var = framework.default_main_program().current_block(
+        ).create_var(
+            name=framework.unique_name("local_norm"),
+            dtype="float32",
+            persistable=False)
         cls.clip_norm_var = layers.fill_constant(
             shape=[1], dtype="float32", value=clip_norm)
 
     @classmethod
     def check_init(cls):
         if not (isinstance(cls.global_norm_var, framework.Variable) and
+                isinstance(cls.local_norm_var, framework.Variable) and
                 isinstance(cls.clip_norm_var, framework.Variable)):
             raise ValueError(
                 "Class 'GradientClipByGlobalNorm' has not been properly initialized. \
@@ -138,17 +145,18 @@ def process_context(self, context, param, grad):
         cls = self.__class__
         cls.check_init()
 
-        local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
+        cls.local_norm_var = layers.reduce_sum(
+            input=layers.pow(x=grad, factor=2.0))
         layers.sums(
-            input=[local_norm_var, cls.global_norm_var],
+            input=[cls.local_norm_var, cls.global_norm_var],
             out=[cls.global_norm_var])
 
     def create_operators(self, param, grad):
         cls = self.__class__
         cls.check_init()
 
         if cls.scale_var is None:
-            cls.global_norm_var = layers.sqrt(x=cls.global_norm_var)
+            layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var)
             cls.scale_var = layers.elementwise_div(
                 x=cls.clip_norm_var,
                 y=layers.elementwise_max(
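
Note: the operators assembled above implement standard global-norm clipping: process_context accumulates each gradient's squared L2 norm into global_norm_var, and create_operators takes a final sqrt and scales every gradient by clip_norm / max(global_norm, clip_norm). A minimal NumPy sketch of that arithmetic, separate from the Fluid operators in the diff (illustrative names, not part of the commit):

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # global_norm = sqrt(sum over all gradients of sum(g ** 2)),
    # mirroring the reduce_sum / sums / sqrt ops built above
    global_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    # scale < 1 only when the global norm exceeds the bound;
    # otherwise max() returns clip_norm and the scale is exactly 1
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]
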
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+
+def _get_global_param_norm_(params_grads):
+    res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0)
+    for _, grad in params_grads:
+        norm_var = fluid.layers.reduce_sum(
+            input=fluid.layers.pow(x=grad, factor=2.0))
+        fluid.layers.sums(input=[norm_var, res], out=[res])
+    fluid.layers.sqrt(x=res, out=res)
+    return res
+
+
+BATCH_SIZE = 128
+CLIP = 0.5
+prog = fluid.framework.Program()
+
+with fluid.program_guard(main_program=prog):
+    image = fluid.layers.data(name='x', shape=[784], dtype='float32')
+
+    hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
+    hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
+    predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
+
+    label = fluid.layers.data(name='y', shape=[1], dtype='int64')
+
+    cost = fluid.layers.cross_entropy(input=predict, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+
+prog_clip = prog.clone()
+
+avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
+
+p_g = fluid.backward.append_backward(loss=avg_cost)
+p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
+
+with fluid.program_guard(main_program=prog):
+    global_norm = _get_global_param_norm_(p_g)
+
+with fluid.program_guard(main_program=prog_clip):
+    fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
+    p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
+    global_norm_clip = _get_global_param_norm_(p_g_clip)
+
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.mnist.train(), buf_size=8192),
+    batch_size=BATCH_SIZE)
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
+exe.run(fluid.default_startup_program())
+
+count = 0
+for data in train_reader():
+    count += 1
+    if count > 5:
+        break
+    out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[global_norm])
+    out_clip, = exe.run(prog_clip,
+                        feed=feeder.feed(data),
+                        fetch_list=[global_norm_clip])
+
+    if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))):
+        exit(1)
+exit(0)
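
Note: the allclose check relies on the identity that scaling every gradient by clip_norm / max(global_norm, clip_norm) leaves the post-clip global norm at exactly min(global_norm, clip_norm). A hedged sketch of that invariant in plain NumPy (raw_norm is an illustrative value, not one fetched from the program):

import numpy as np

CLIP = 0.5
raw_norm = np.array([1.7], dtype="float32")  # example pre-clip global norm
scale = CLIP / max(raw_norm.item(), CLIP)    # factor applied to every gradient
clipped_norm = raw_norm * scale              # global norm after clipping
# matches the unittest's expectation: min(raw_norm, CLIP)
assert np.allclose(clipped_norm, np.minimum(raw_norm, np.array([CLIP])))
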
