Commit a173fa7

Merge pull request #7732 from JiayiFeng/refine_grad_clip_api
update gradient clip api
2 parents: 1575c2c + 5fc498e

File tree: 2 files changed, 13 insertions(+), 8 deletions(-)

python/paddle/v2/fluid/clip.py

Lines changed: 11 additions & 7 deletions

@@ -12,14 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+
 import functools
 import layers
 import framework
 from . import core
 
 __all__ = [
-    'GradientClipByValue',
     'ErrorClipByValue',
+    'GradientClipByValue',
+    'GradientClipByNorm',
+    'GradientClipByGlobalNorm',
     'append_gradient_clip_ops',
     'error_clip_callback',
 ]
@@ -155,10 +159,11 @@ def create_operators(self, param, grad):
         return param, new_grad
 
 
-def gradient_clip_by_global_norm(clip_norm,
-                                 param_list=None,
-                                 group_name="default_group",
-                                 program=None):
+def set_gradient_clip(clip, param_list=None, program=None):
+    if not isinstance(clip, BaseGradientClipAttr):
+        raise TypeError(
+            "'clip' should be an instance of BaseGradientClipAttr's derived class"
+        )
     if program is None:
         program = framework.default_main_program()
     if param_list is None:
@@ -171,8 +176,7 @@ def gradient_clip_by_global_norm(clip_norm,
         )
 
     for param in param_list:
-        param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
-                                                            group_name)
+        param.gradient_clip_attr = copy.deepcopy(clip)
 
 
 def append_gradient_clip_ops(param_grad):
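In short, this hunk replaces the single-purpose `gradient_clip_by_global_norm` helper with a generic `set_gradient_clip` that accepts any `BaseGradientClipAttr` subclass and deep-copies it onto each parameter, so the newly exported `GradientClipByNorm` and `GradientClipByGlobalNorm` share one entry point. A minimal usage sketch (the network and loss construction are omitted and assumed for illustration; only the `fluid.clip` calls come from this PR):

```python
import paddle.v2.fluid as fluid

prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
    # ... build the network and a loss variable here (omitted) ...

    # New API: construct any BaseGradientClipAttr subclass and attach it.
    fluid.clip.set_gradient_clip(
        fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))

    # The same entry point now also covers the other exported classes,
    # e.g. fluid.clip.GradientClipByNorm(clip_norm=5.0).

    # Old API removed by this PR:
    # fluid.clip.gradient_clip_by_global_norm(clip_norm=5.0)
```

Note the `copy.deepcopy(clip)` in the new implementation: each parameter gets its own clip-attribute instance, so mutating one parameter's clip settings later cannot leak to the others (our reading of the diff, not stated in the PR).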

python/paddle/v2/fluid/tests/test_gradient_clip.py

Lines changed: 2 additions & 1 deletion

@@ -40,7 +40,8 @@
 p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
 
 with fluid.program_guard(main_program=prog_clip):
-    fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
+    fluid.clip.set_gradient_clip(
+        fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
     p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
 
 grad_list = [elem[1] for elem in p_g]
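As context for what the test exercises: global-norm clipping scales every gradient by `clip_norm / max(clip_norm, global_norm)`, so the clipped set's global norm never exceeds `clip_norm`. A hypothetical numpy check of that invariant (`global_norm` and `check_global_norm_clip` are illustrative helpers, not part of the test file shown above):

```python
import numpy as np

def global_norm(grads):
    # L2 norm of all gradient tensors taken together.
    return np.sqrt(sum(np.sum(np.square(g)) for g in grads))

def check_global_norm_clip(unclipped, clipped, clip_norm, tol=1e-6):
    # Every gradient should be the unclipped one scaled by the same factor,
    # and the resulting global norm should be capped at clip_norm.
    scale = clip_norm / max(clip_norm, global_norm(unclipped))
    for u, c in zip(unclipped, clipped):
        assert np.allclose(c, u * scale, atol=tol)
    assert global_norm(clipped) <= clip_norm + tol
```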
