Skip to content

Commit 19c554f

Browse files
committed
update
1 parent 538f1ad commit 19c554f

File tree

2 files changed

+59
-67
lines changed

2 files changed

+59
-67
lines changed

python/paddle/v2/fluid/clip.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -112,58 +112,52 @@ def create_operators(self, param, grad):
112112

113113

114114
class GradientClipByGlobalNorm(BaseGradientClipAttr):
115-
global_norm_var = None
116-
local_norm_var = None
117-
clip_norm_var = None
118-
scale_var = None
119-
120-
@classmethod
121-
def init(cls, clip_norm):
122-
if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)):
123-
raise TypeError("The 'clip_norm' must be a value of int or float")
124-
125-
cls.global_norm_var = layers.fill_constant(
126-
shape=[1], dtype="float32", value=0.0)
127-
cls.local_norm_var = layers.create_tensor(dtype="float32")
128-
cls.clip_norm_var = layers.fill_constant(
129-
shape=[1], dtype="float32", value=clip_norm)
130-
131-
@classmethod
132-
def check_init(cls):
133-
if not (isinstance(cls.global_norm_var, framework.Variable) and
134-
isinstance(cls.local_norm_var, framework.Variable) and
135-
isinstance(cls.clip_norm_var, framework.Variable)):
136-
raise ValueError(
137-
"Class 'GradientClipByGlobalNorm' has not been properly initialized. \
138-
Please call GradientClipByGlobalNorm.init() first.")
115+
def __init__(self, clip_norm, group_name="default_group"):
116+
if not isinstance(group_name, basestring):
117+
raise TypeError("'group_name' must be a basestring.")
118+
119+
self.clip_norm = clip_norm
120+
self.group_name = group_name
139121

140122
def process_context(self, context, param, grad):
141-
cls = self.__class__
142-
cls.check_init()
123+
if self.group_name not in context:
124+
context[self.group_name] = []
125+
context[self.group_name + "_clip_value"] = self.clip_norm
126+
context[self.group_name + "_clip"] = layers.fill_constant(
127+
shape=[1], dtype="float32", value=self.clip_norm)
128+
else:
129+
if not self.clip_norm == context[self.group_name + "_clip_value"]:
130+
raise ValueError(
131+
"All parameters' 'clip_norm' of a same group should be the same"
132+
)
143133

144-
cls.local_norm_var = layers.reduce_sum(
145-
input=layers.pow(x=grad, factor=2.0))
146-
layers.sums(
147-
input=[cls.local_norm_var, cls.global_norm_var],
148-
out=[cls.global_norm_var])
134+
local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
135+
context[self.group_name].append(local_norm_var)
149136

150-
def create_operators(self, param, grad):
151-
cls = self.__class__
152-
cls.check_init()
137+
self.context = context
153138

154-
if cls.scale_var is None:
155-
layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var)
156-
cls.scale_var = layers.elementwise_div(
157-
x=cls.clip_norm_var,
139+
def create_operators(self, param, grad):
140+
group_scale_name = self.group_name + "_scale"
141+
if group_scale_name not in self.context:
142+
group_norm_var = layers.sums(input=self.context[self.group_name])
143+
layers.sqrt(x=group_norm_var, out=group_norm_var)
144+
clip_var = self.context[self.group_name + "_clip"]
145+
group_scale_var = layers.elementwise_div(
146+
x=clip_var,
158147
y=layers.elementwise_max(
159-
x=cls.clip_norm_var, y=cls.global_norm_var))
160-
assert cls.scale_var.shape == (1L, )
148+
x=clip_var, y=group_norm_var))
149+
assert group_scale_var.shape == (1L, )
150+
self.context[group_scale_name] = group_scale_var
161151

162-
new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var)
152+
new_grad = layers.elementwise_mul(
153+
x=grad, y=self.context[group_scale_name])
163154
return param, new_grad
164155

165156

166-
def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
157+
def gradient_clip_by_global_norm(clip_norm,
158+
param_list=None,
159+
group_name="default_group",
160+
program=None):
167161
if program is None:
168162
program = framework.default_main_program()
169163
if param_list is None:
@@ -175,9 +169,9 @@ def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
175169
"'param_list' should be a list of Parameter or basestring(parameter's name)."
176170
)
177171

178-
GradientClipByGlobalNorm.init(clip_norm)
179172
for param in param_list:
180-
param.gradient_clip_attr = GradientClipByGlobalNorm()
173+
param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
174+
group_name)
181175

182176

183177
def append_gradient_clip_ops(param_grad):

python/paddle/v2/fluid/tests/test_gradient_clip.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,10 @@
1515
import paddle.v2 as paddle
1616
import paddle.v2.fluid as fluid
1717

18-
19-
def _get_global_param_norm_(params_grads):
20-
res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0)
21-
for _, grad in params_grads:
22-
norm_var = fluid.layers.reduce_sum(
23-
input=fluid.layers.pow(x=grad, factor=2.0))
24-
fluid.layers.sums(input=[norm_var, res], out=[res])
25-
fluid.layers.sqrt(x=res, out=res)
26-
return res
27-
28-
2918
BATCH_SIZE = 128
30-
CLIP = 0.5
31-
prog = fluid.framework.Program()
19+
CLIP = 1
3220

21+
prog = fluid.framework.Program()
3322
with fluid.program_guard(main_program=prog):
3423
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
3524

@@ -49,13 +38,12 @@ def _get_global_param_norm_(params_grads):
4938
p_g = fluid.backward.append_backward(loss=avg_cost)
5039
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
5140

52-
with fluid.program_guard(main_program=prog):
53-
gloabl_norm = _get_global_param_norm_(p_g)
54-
5541
with fluid.program_guard(main_program=prog_clip):
5642
fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
5743
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
58-
gloabl_norm_clip = _get_global_param_norm_(p_g_clip)
44+
45+
grad_list = [elem[1] for elem in p_g]
46+
grad_clip_list = [elem[1] for elem in p_g_clip]
5947

6048
train_reader = paddle.batch(
6149
paddle.reader.shuffle(
@@ -72,11 +60,21 @@ def _get_global_param_norm_(params_grads):
7260
count += 1
7361
if count > 5:
7462
break
75-
out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm])
76-
out_clip, = exe.run(prog_clip,
77-
feed=feeder.feed(data),
78-
fetch_list=[gloabl_norm_clip])
79-
80-
if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))):
63+
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
64+
out_clip = exe.run(prog_clip,
65+
feed=feeder.feed(data),
66+
fetch_list=grad_clip_list)
67+
global_norm = 0
68+
for v in out[1:]:
69+
global_norm += np.sum(np.power(v, 2))
70+
global_norm = np.sqrt(global_norm)
71+
72+
global_norm_clip = 0
73+
for v in out_clip[1:]:
74+
global_norm_clip += np.sum(np.power(v, 2))
75+
global_norm_clip = np.sqrt(global_norm_clip)
76+
77+
if not np.isclose(
78+
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
8179
exit(1)
8280
exit(0)

0 commit comments

Comments
 (0)