
Commit 74523c4

enhance regularizer.py
1 parent 0d49b92 commit 74523c4

1 file changed: 35 additions, 5 deletions

python/paddle/fluid/regularizer.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import framework
+from . import core
 
 __all__ = [
     'append_regularization_ops',
@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         regularization_term = None
         if param.regularizer is not None:
             # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad.block)
+            regularization_term = param.regularizer(param, grad, grad.block)
         elif regularization is not None:
-            regularization_term = regularization(param, grad.block)
+            regularization_term = regularization(param, grad, grad.block)
 
         # If no gradient or no regularization specified,
         # then we don't need to do anything
@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add corresponding weight decay operations to the network
         """
         raise NotImplementedError()
@@ -102,7 +103,7 @@ def __init__(self, regularization_coeff=0.0):
         super(L2DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L2 weight decay ops to network
 
         Adds L2 weight decay ops.
@@ -117,8 +118,23 @@ def __call__(self, param, block):
         """
         assert isinstance(param, framework.Parameter)
         assert isinstance(block, framework.Block)
+
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append Op to calculate decay
         block.append_op(
             type='scale',
@@ -141,7 +157,7 @@ def __init__(self, regularization_coeff=0.0):
         super(L1DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L1 weight decay ops to network
 
         Adds L1 weight decay ops.
@@ -158,6 +174,20 @@ def __call__(self, param, block):
         assert isinstance(block, framework.Block)
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            # add concat_rows
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+
         # Append sign op
         block.append_op(
             type='sign', inputs={"X": param}, outputs={"Out": decay})
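The heart of both new hunks is the SELECTED_ROWS branch. When the gradient is sparse, its rows index only the parameter rows that were actually updated in the step, so the lookup_table op with is_sparse=True gathers just those rows of the parameter before the decay term is computed, instead of regularizing the whole table. A rough NumPy sketch of the effect (illustration only, not Paddle code; table, rows, and coeff are made-up names):

    import numpy as np

    # A large embedding-style parameter and a sparse gradient that only
    # touched three of its rows (the analogue of a SELECTED_ROWS variable).
    table = np.random.rand(1000, 64).astype("float32")  # parameter W
    rows = np.array([3, 17, 256])                       # ids of the touched rows
    grad_values = np.random.rand(3, 64).astype("float32")

    coeff = 0.01  # regularization_coeff

    # Analogue of the lookup_table gather in the diff: pull out only the
    # parameter rows referenced by the sparse gradient.
    gathered = table[rows]

    # The decay term, restricted to those rows:
    l2_decay = coeff * gathered              # L2: the 'scale' op
    l1_decay = coeff * np.sign(gathered)     # L1: 'sign' followed by 'scale'

    # Downstream, the decay rows are combined with the matching gradient
    # rows; the other 997 rows of the table are left untouched.
    updated_grad_values = grad_values + l2_decay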

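Note that the base-class signature changes from __call__(self, param, block) to __call__(self, param, grad, block), so any regularizer subclass written against the old two-argument interface must be updated. A minimal sketch of the post-commit shape of a subclass; MyDecayRegularizer is hypothetical, and its scale-op body mirrors L2DecayRegularizer's dense path rather than anything new in this diff:

    import framework
    from . import core


    class MyDecayRegularizer(WeightDecayRegularizer):
        """Hypothetical subclass showing the new three-argument interface."""

        def __init__(self, regularization_coeff=0.0):
            super(MyDecayRegularizer, self).__init__()
            self._regularization_coeff = regularization_coeff

        def __call__(self, param, grad, block):  # grad is the newly added argument
            assert isinstance(param, framework.Parameter)
            assert isinstance(block, framework.Block)
            decay = block.create_var(
                dtype="float32", shape=param.shape, lod_level=param.lod_level)
            # With grad in scope, a subclass can branch on
            # grad.type == core.VarDesc.VarType.SELECTED_ROWS and gather the
            # touched rows first, as the L1/L2 regularizers above now do.
            block.append_op(
                type='scale',
                inputs={"X": param},
                outputs={"Out": decay},
                attrs={"scale": self._regularization_coeff})
            return decay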