@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import framework
+from . import core
 
 __all__ = [
     'append_regularization_ops',
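
The new `core` import exposes `core.VarDesc.VarType.SELECTED_ROWS`, the variable type the hunks below check against to detect sparse gradients.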
@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         regularization_term = None
         if param.regularizer is not None:
             # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad.block)
+            regularization_term = param.regularizer(param, grad, grad.block)
         elif regularization is not None:
-            regularization_term = regularization(param, grad.block)
+            regularization_term = regularization(param, grad, grad.block)
 
         # If no gradient or no regularization specified,
         # then we don't need to do anything
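
With this change, `append_regularization_ops` hands the gradient itself to every regularizer callable, so an implementation can inspect the gradient's type before emitting decay ops. A minimal sketch of a custom regularizer conforming to the new three-argument signature (the class name is hypothetical, not part of this PR; it assumes the module-level `framework` and `core` imports above):

class MyDecayRegularizer(WeightDecayRegularizer):
    def __call__(self, param, grad, block):
        # `grad` is now passed in alongside `param`, so the implementation
        # can branch on grad.type to treat dense and SELECTED_ROWS (sparse)
        # gradients differently before appending decay ops to `block`.
        raise NotImplementedError()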
@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add corresponding weight decay operations to the network
         """
         raise NotImplementedError()
@@ -102,7 +103,7 @@ def __init__(self, regularization_coeff=0.0):
         super(L2DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L2 weight decay ops to network
 
         Adds L2 weight decay ops.
@@ -117,8 +118,23 @@ def __call__(self, param, block):
         """
         assert isinstance(param, framework.Parameter)
         assert isinstance(block, framework.Block)
+
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append Op to calculate decay
         block.append_op(
             type='scale',
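
This is the sparse path: a SELECTED_ROWS gradient carries only the rows of the parameter that were actually updated (for example, the rows an embedding lookup touched). The `lookup_table` op reuses the gradient's row ids to gather exactly those rows out of `param`, and the reassignment `param = decay` makes the following `scale` op compute the decay on the gathered slice rather than on the full dense tensor. In dense NumPy terms the branch computes roughly the following (an illustrative sketch only; `rows` stands for the row ids carried by the SELECTED_ROWS gradient, and `coeff` for `self._regularization_coeff`):

import numpy as np

def sparse_l2_decay(param, rows, coeff):
    # Mirrors lookup_table(W=param, Ids=rows) followed by
    # scale(scale=coeff): gather the touched rows, then scale them.
    return coeff * param[rows]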
@@ -141,7 +157,7 @@ def __init__(self, regularization_coeff=0.0):
         super(L1DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff
 
-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L1 weight decay ops to network
 
         Adds L1 weight decay ops.
@@ -158,6 +174,19 @@ def __call__(self, param, block):
         assert isinstance(block, framework.Block)
         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)
+
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+
         # Append sign op
         block.append_op(
             type='sign', inputs={"X": param}, outputs={"Out": decay})
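
The L1 branch gathers the touched rows the same way, but feeds the result into a `sign` op; note that, unlike the L2 hunk, it does not reassign `param = decay`, so the `sign` op still reads the dense parameter. The per-row computation the branch targets is roughly (same illustrative assumptions as the sketch above):

def sparse_l1_decay(param, rows, coeff):
    # Mirrors lookup_table followed by sign() and a later scale():
    # the scaled sign of only the rows the sparse gradient touched.
    return coeff * np.sign(param[rows])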