
Commit 065ffcc

fix dgcclipnorm bug test=develop (#16629)
1 parent 7964366 commit 065ffcc

File tree

2 files changed: +17 additions, -14 deletions

paddle/fluid/operators/dgc_clip_by_norm_op.h

Lines changed: 15 additions & 12 deletions
@@ -24,18 +24,21 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
-    if (static_cast<int>(rampup_begin_step) >= 0) {
-      auto current_step_tensor =
-          context.Input<framework::Tensor>("current_step");
-      auto* current_step = current_step_tensor->data<T>();
-
-      if (static_cast<int>(*current_step) <
-          static_cast<int>(rampup_begin_step)) {
-        VLOG(10) << "current_step:" << *current_step
-                 << " < rampup_begin_step:" << rampup_begin_step
-                 << " so does't use dgc_clip_by_norm";
-        return;
-      }
+    if (static_cast<int>(rampup_begin_step) < 0) {
+      return;
+    }
+
+    auto current_step_tensor = context.Input<framework::Tensor>("current_step");
+    auto* current_step = current_step_tensor->data<T>();
+
+    VLOG(10) << "current_step:" << *current_step
+             << ", rampup_begin_step:" << rampup_begin_step;
+
+    if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
+      VLOG(10) << "current_step:" << *current_step
+               << " < rampup_begin_step:" << rampup_begin_step
+               << " so does't use dgc_clip_by_norm";
+      return;
     }
 
     return ClipByNormKernel<DeviceContext, T>::Compute(context);
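In short, the kernel now uses early returns instead of nesting everything inside the old check of rampup_begin_step >= 0: a negative rampup_begin_step disables DGC clipping entirely, a current_step still below rampup_begin_step skips clipping during warm-up, and only after both checks does the kernel fall through to ClipByNormKernel::Compute. Below is a minimal standalone C++ sketch of that guard logic, for illustration only; the function name and plain-float parameters are invented for the sketch and are not part of the Paddle API (the real kernel reads these values from the ExecutionContext).

#include <iostream>

// Illustrative stand-in for the reworked guard logic in DGCClipByNormKernel.
// ShouldApplyDGCClip is a hypothetical helper, not a Paddle function.
bool ShouldApplyDGCClip(float rampup_begin_step, float current_step) {
  // A negative rampup_begin_step disables DGC clipping entirely.
  if (static_cast<int>(rampup_begin_step) < 0) {
    return false;
  }
  // Before the rampup step is reached, clipping is skipped for this step.
  if (static_cast<int>(current_step) < static_cast<int>(rampup_begin_step)) {
    return false;
  }
  // Otherwise the kernel would fall through to the normal clip-by-norm pass.
  return true;
}

int main() {
  std::cout << ShouldApplyDGCClip(-1.f, 10.f) << "\n";  // 0: DGC clip disabled
  std::cout << ShouldApplyDGCClip(5.f, 3.f) << "\n";    // 0: still ramping up
  std::cout << ShouldApplyDGCClip(5.f, 8.f) << "\n";    // 1: clip is applied
  return 0;
}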

python/paddle/fluid/optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -832,7 +832,7 @@ def _clip_by_norm(self, x, max_norm, name=None):
             type=x.type, name=name, dtype=x.dtype, persistable=False)
 
         helper.append_op(
-            type="clip_by_norm",
+            type="dgc_clip_by_norm",
             inputs={"X": x,
                     "current_step": self._global_step_var},
             attrs={
@@ -845,7 +845,7 @@ def _clip_by_norm(self, x, max_norm, name=None):
     def _append_clip_norm(self, grad_var, clip_norm):
         with grad_var.block.program._backward_role_guard():
             return self._clip_by_norm(
-                x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC")
+                x=grad_var, max_norm=clip_norm, name=grad_var.name)
 
     def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
                 encoded_var):
