@@ -185,7 +185,17 @@ class WeightNormParamAttr(ParamAttr):
 
     Args:
         dim(list): The dimension(s) over which to compute the norm. Default None.
-        kwargs: Any field in ParamAttr. Default None.
+        name(str): The parameter's name. Default None.
+        initializer(Initializer): The method to initialize this parameter. Default None.
+        learning_rate(float): The parameter's learning rate. The learning rate
+            used during optimization is :math:`global\_lr * parameter\_lr * scheduler\_factor`.
+            Default 1.0.
+        regularizer(WeightDecayRegularizer): Regularization factor. Default None.
+        trainable(bool): Whether this parameter is trainable. Default True.
+        gradient_clip(BaseGradientClipAttr): The method to clip this parameter's
+            gradient. Default None.
+        do_model_average(bool): Whether this parameter should apply model averaging.
+            Default False.
 
     Examples:
         .. code-block:: python
@@ -204,6 +214,21 @@ class WeightNormParamAttr(ParamAttr):
     # these parameters for inference.
     params_with_weight_norm = []
 
-    def __init__(self, dim=None, **kwargs):
-        super(WeightNormParamAttr, self).__init__(**kwargs)
+    def __init__(self,
+                 dim=None,
+                 name=None,
+                 initializer=None,
+                 learning_rate=1.0,
+                 regularizer=None,
+                 trainable=True,
+                 gradient_clip=None,
+                 do_model_average=False):
+        super(WeightNormParamAttr, self).__init__(
+            name=name,
+            initializer=initializer,
+            learning_rate=learning_rate,
+            regularizer=regularizer,
+            trainable=trainable,
+            gradient_clip=gradient_clip,
+            do_model_average=do_model_average)
         self.dim = dim
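
With this change, every `ParamAttr` field is an explicit keyword argument of `WeightNormParamAttr` instead of being forwarded through `**kwargs`. A minimal usage sketch under the fluid API of this era; the `fluid.layers.data`/`fluid.layers.fc` network and all argument values are illustrative assumptions, not part of this diff:

```python
import paddle.fluid as fluid
from paddle.fluid.param_attr import WeightNormParamAttr

# Hypothetical example: attach weight normalization to an FC layer's weight,
# passing the now-explicit ParamAttr fields by name.
data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32")
fc = fluid.layers.fc(
    input=data,
    size=1000,
    param_attr=WeightNormParamAttr(
        dim=None,                                   # normalize over all dimensions
        name="weight_norm_param",
        initializer=fluid.initializer.Constant(1.0),
        learning_rate=1.0,
        regularizer=fluid.regularizer.L2Decay(0.1),
        trainable=True,
        do_model_average=False))
```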