Skip to content

Commit 610ad49

Browse files
authored
Merge pull request #7637 from JiayiFeng/dev_global_norm_clip
Gradient clip by global norm
2 parents f45b0b0 + e8adcaf commit 610ad49

File tree

7 files changed

+173
-27
lines changed

7 files changed

+173
-27
lines changed

python/paddle/v2/fluid/clip.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import functools
1616
import layers
17+
import framework
1718
from . import core
1819

1920
__all__ = [
@@ -66,15 +67,15 @@ def error_clip_callback(block, context):
6667

6768

6869
class BaseGradientClipAttr(object):
69-
def process_context(self, context, p_g):
70+
def process_context(self, context, param, grad):
7071
raise NotImplementedError()
7172

7273
def create_operators(self, param, grad):
7374
raise NotImplementedError()
7475

7576

7677
class NullGradientClipAttr(BaseGradientClipAttr):
77-
def process_context(self, context, p_g):
78+
def process_context(self, context, param, grad):
7879
pass
7980

8081
def create_operators(self, param, grad):
@@ -91,27 +92,101 @@ def __init__(self, max, min=None):
9192
self.max = max
9293
self.min = min
9394

94-
def process_context(self, context, p_g):
95+
def process_context(self, context, param, grad):
9596
pass
9697

9798
def create_operators(self, param, grad):
9899
new_grad = layers.clip(x=grad, min=self.min, max=self.max)
99100
return param, new_grad
100101

101102

103+
class GradientClipByNorm(BaseGradientClipAttr):
104+
def __init__(self, clip_norm):
105+
self.clip_norm = clip_norm
106+
107+
def process_context(self, context, param, grad):
108+
pass
109+
110+
def create_operators(self, param, grad):
111+
new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
112+
return param, new_grad
113+
114+
115+
class GradientClipByGlobalNorm(BaseGradientClipAttr):
116+
def __init__(self, clip_norm, group_name="default_group"):
117+
if not isinstance(group_name, basestring):
118+
raise TypeError("'group_name' must be a basestring.")
119+
120+
self.clip_norm = clip_norm
121+
self.group_name = group_name
122+
123+
def process_context(self, context, param, grad):
124+
if self.group_name not in context:
125+
context[self.group_name] = []
126+
context[self.group_name + "_clip_value"] = self.clip_norm
127+
context[self.group_name + "_clip"] = layers.fill_constant(
128+
shape=[1], dtype="float32", value=self.clip_norm)
129+
else:
130+
if not self.clip_norm == context[self.group_name + "_clip_value"]:
131+
raise ValueError(
132+
"All parameters' 'clip_norm' of a same group should be the same"
133+
)
134+
135+
local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
136+
context[self.group_name].append(local_norm_var)
137+
138+
self.context = context
139+
140+
def create_operators(self, param, grad):
141+
group_scale_name = self.group_name + "_scale"
142+
if group_scale_name not in self.context:
143+
group_norm_var = layers.sums(input=self.context[self.group_name])
144+
layers.sqrt(x=group_norm_var, out=group_norm_var)
145+
clip_var = self.context[self.group_name + "_clip"]
146+
group_scale_var = layers.elementwise_div(
147+
x=clip_var,
148+
y=layers.elementwise_max(
149+
x=clip_var, y=group_norm_var))
150+
assert group_scale_var.shape == (1L, )
151+
self.context[group_scale_name] = group_scale_var
152+
153+
new_grad = layers.elementwise_mul(
154+
x=grad, y=self.context[group_scale_name])
155+
return param, new_grad
156+
157+
158+
def gradient_clip_by_global_norm(clip_norm,
159+
param_list=None,
160+
group_name="default_group",
161+
program=None):
162+
if program is None:
163+
program = framework.default_main_program()
164+
if param_list is None:
165+
param_list = program.block(0).all_parameters()
166+
if all(isinstance(elem, basestring) for elem in param_list):
167+
param_list = [program.block(0).var(elem) for elem in param_list]
168+
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
169+
raise TypeError(
170+
"'param_list' should be a list of Parameter or basestring(parameter's name)."
171+
)
172+
173+
for param in param_list:
174+
param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
175+
group_name)
176+
177+
102178
def append_gradient_clip_ops(param_grad):
103179
context = dict()
104180
create_op_callbacks = []
105181
for p, g in param_grad:
106-
clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr())
182+
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
107183
if clip_attr is None:
108184
clip_attr = NullGradientClipAttr()
109185
if not isinstance(clip_attr, BaseGradientClipAttr):
110186
raise TypeError(
111-
"clip attribute should be an instance of BaseGradientClippingAttr"
112-
)
187+
"clip attribute should be an instance of BaseGradientClipAttr")
113188

114-
clip_attr.process_context(context=context, p_g=param_grad)
189+
clip_attr.process_context(context=context, param=p, grad=g)
115190
create_op_callbacks.append(
116191
functools.partial(
117192
clip_attr.create_operators, param=p, grad=g))

python/paddle/v2/fluid/framework.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def copy_param_info_from(self, other):
780780
trainable=p.trainable,
781781
optimize_attr=p.optimize_attr,
782782
regularizer=p.regularizer,
783-
clip_attr=p.clip_attr,
783+
gradient_clip_attr=p.gradient_clip_attr,
784784
error_clip=p.error_clip,
785785
name=v.name)
786786
self.vars[new_p.name] = new_p
@@ -948,7 +948,7 @@ def __init__(self, block, shape, dtype, **kwargs):
948948

949949
self.regularizer = kwargs.get('regularizer', None)
950950

951-
self.clip_attr = kwargs.get('clip_attr', None)
951+
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
952952

953953

954954
# program is a global instance.

python/paddle/v2/fluid/layers/ops.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,10 @@
4646
]
4747

4848
__all__ = [
49-
'mean',
50-
'mul',
51-
'reshape',
52-
'scale',
53-
'transpose',
54-
'sigmoid_cross_entropy_with_logits',
55-
'elementwise_add',
56-
'elementwise_div',
57-
'elementwise_sub',
58-
'elementwise_mul',
59-
'elementwise_max',
60-
'elementwise_min',
61-
'clip',
62-
'sequence_softmax',
49+
'mean', 'mul', 'reshape', 'scale', 'transpose',
50+
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
51+
'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min',
52+
'clip', 'clip_by_norm', 'sequence_softmax'
6353
] + __activations__
6454

6555
for _OP in set(__all__):

python/paddle/v2/fluid/param_attr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ def __init__(self,
2525
learning_rate=1.0,
2626
regularizer=None,
2727
trainable=True,
28-
clip=None):
28+
gradient_clip=None):
2929
self.name = name
3030
self.initializer = initializer
3131
self.learning_rate = learning_rate
3232
self.regularizer = regularizer
3333
self.trainable = trainable
34-
self.clip = clip
34+
self.gradient_clip = gradient_clip
3535

3636
def set_default_initializer(self, initializer):
3737
if initializer is None:
@@ -77,7 +77,7 @@ def to_kwargs(self, with_initializer=False):
7777
},
7878
'regularizer': self.regularizer,
7979
'trainable': self.trainable,
80-
'clip_attr': self.clip
80+
'gradient_clip_attr': self.gradient_clip
8181
}
8282
if with_initializer:
8383
kwargs['initializer'] = self.initializer

python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
act='relu',
2828
param_attr=fluid.ParamAttr(
2929
regularizer=regularizer,
30-
clip=fluid.clip.ClipByValue(10)))
30+
gradient_clip=fluid.clip.ClipByValue(10)))
3131

3232
hidden2 = fluid.layers.fc(input=hidden1,
3333
size=64,
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import paddle.v2 as paddle
17+
import paddle.v2.fluid as fluid
18+
19+
BATCH_SIZE = 128
20+
CLIP = 1
21+
22+
prog = fluid.framework.Program()
23+
with fluid.program_guard(main_program=prog):
24+
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
25+
26+
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
27+
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
28+
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
29+
30+
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
31+
32+
cost = fluid.layers.cross_entropy(input=predict, label=label)
33+
avg_cost = fluid.layers.mean(x=cost)
34+
35+
prog_clip = prog.clone()
36+
37+
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
38+
39+
p_g = fluid.backward.append_backward(loss=avg_cost)
40+
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
41+
42+
with fluid.program_guard(main_program=prog_clip):
43+
fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
44+
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
45+
46+
grad_list = [elem[1] for elem in p_g]
47+
grad_clip_list = [elem[1] for elem in p_g_clip]
48+
49+
train_reader = paddle.batch(
50+
paddle.reader.shuffle(
51+
paddle.dataset.mnist.train(), buf_size=8192),
52+
batch_size=BATCH_SIZE)
53+
54+
place = fluid.CPUPlace()
55+
exe = fluid.Executor(place)
56+
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
57+
exe.run(fluid.default_startup_program())
58+
59+
count = 0
60+
for data in train_reader():
61+
count += 1
62+
if count > 5:
63+
break
64+
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
65+
out_clip = exe.run(prog_clip,
66+
feed=feeder.feed(data),
67+
fetch_list=grad_clip_list)
68+
global_norm = 0
69+
for v in out[1:]:
70+
global_norm += np.sum(np.power(v, 2))
71+
global_norm = np.sqrt(global_norm)
72+
73+
global_norm_clip = 0
74+
for v in out_clip[1:]:
75+
global_norm_clip += np.sum(np.power(v, 2))
76+
global_norm_clip = np.sqrt(global_norm_clip)
77+
78+
if not np.isclose(
79+
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
80+
exit(1)
81+
exit(0)

0 commit comments

Comments
 (0)