Commit a8f6636

Merge pull request #13524 from sneaxiy/fix_api_kwargs
Remove kwargs in elementwise layers and scale layer
2 parents 175a2ef + 70e70d7 commit a8f6636

File tree

9 files changed: +150 −27 lines changed


paddle/fluid/API.spec

Lines changed: 8 additions & 8 deletions
@@ -178,6 +178,14 @@ paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, key
 paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.expand ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_concat ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.scale ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None))
+paddle.fluid.layers.elementwise_add ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_div ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_sub ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_mul ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_max ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_min ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
+paddle.fluid.layers.elementwise_pow ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, False, None, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
@@ -242,15 +250,7 @@ paddle.fluid.layers.Print ArgSpec(args=['input', 'first_n', 'message', 'summariz
 paddle.fluid.layers.is_empty ArgSpec(args=['x', 'cond'], varargs=None, keywords='ignored', defaults=(None,))
 paddle.fluid.layers.mean ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.mul ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.scale ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_add ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_div ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_sub ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_mul ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_max ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_min ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elementwise_pow ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.clip ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.clip_by_norm ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.logical_and ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
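
The spec change above is the user-visible effect of the refactor: scale and the seven elementwise layers move from opaque *args/**kwargs signatures to explicit ones. A minimal sketch of the new call style (variable names are illustrative, not taken from this commit):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.data(name='y', shape=[4], dtype='float32')

    # keyword arguments now match the ArgSpec entries above
    s = fluid.layers.scale(x, scale=2.0, bias=1.0, bias_after_scale=True)
    z = fluid.layers.elementwise_add(x, y, act='relu', name='sum_relu')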

paddle/fluid/operators/scale_op.cc

Lines changed: 8 additions & 2 deletions
@@ -46,9 +46,15 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 **Scale operator**

-Multiply the input tensor with a float scalar to scale the input tensor.
+Apply scaling and bias addition to the input tensor.

-$$Out = scale*X$$
+if bias_after_scale=True:
+
+$$Out = scale*X + bias$$
+
+else:
+
+$$Out = scale*(X + bias)$$
 )DOC");
     AddAttr<float>("scale", "The scaling factor of the scale operator.")
         .SetDefault(1.0);
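
The two formulas differ only in where the bias enters. A quick plain-numpy check of the documented semantics (illustration only, not Paddle code):

    import numpy as np

    x, scale, bias = np.array([1.0, 2.0]), 3.0, 0.5
    out_after = scale * x + bias       # bias_after_scale=True  -> [3.5, 6.5]
    out_before = scale * (x + bias)    # bias_after_scale=False -> [4.5, 7.5]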

python/paddle/fluid/framework.py

Lines changed: 2 additions & 1 deletion
@@ -489,7 +489,8 @@ def get_op_proto(self, type):
 def generated_op_attr_names():
     return {
         core.op_proto_and_checker_maker.kOpRoleAttrName(),
-        core.op_proto_and_checker_maker.kOpRoleVarAttrName()
+        core.op_proto_and_checker_maker.kOpRoleVarAttrName(),
+        core.op_proto_and_checker_maker.kOpNameScopeAttrName()
     }

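generated_op_attr_names() enumerates attributes that the framework injects into every op; registering kOpNameScopeAttrName() here keeps the new name-scope attribute out of user-facing signatures and docs. A hedged sketch of how such a skip-set is typically consumed (the call site below is an assumption, not part of this diff):

    # hypothetical consumer: drop framework-generated attrs when emitting docs
    skip = generated_op_attr_names()
    user_attrs = [attr for attr in op_proto.attrs if attr.name not in skip]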
python/paddle/fluid/layers/layer_function_generator.py

Lines changed: 8 additions & 1 deletion
@@ -61,7 +61,7 @@ def escape_math(text):
     _two_dollar_pattern_.sub(r"!!\1!!", text)))


-def _generate_doc_string_(op_proto):
+def _generate_doc_string_(op_proto, additional_args_lines=None):
     """
     Generate docstring by OpProto
@@ -101,6 +101,13 @@ def _generate_doc_string_(op_proto):
         buf.write(escape_math(each_attr.comment))
         buf.write('\n')

+    if additional_args_lines is not None:
+        for line in additional_args_lines:
+            line = line.strip()
+            buf.write('    ')
+            buf.write(line)
+            buf.write('\n')
+
     if len(op_proto.outputs) != 0:
         buf.write('\nReturns:\n')
         buf.write('    ')
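
The new additional_args_lines hook lets callers append hand-written entries under the generated Args section; the nn.py hunk later in this commit uses it for exactly the act and name parameters. A minimal sketch, assuming op_proto is a valid OpProto:

    doc = _generate_doc_string_(
        op_proto,
        additional_args_lines=[
            "act (basestring|None): Activation applied to the output.",
            "name (basestring|None): Name of the output."
        ])
    # each stripped line is written indented under Args, followed by '\n'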

python/paddle/fluid/layers/learning_rate_scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ def noam_decay(d_model, warmup_steps):

     a = global_step**-0.5
     b = (warmup_steps**-1.5) * global_step
-    lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
+    lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)

     return lr_value

@@ -241,7 +241,7 @@ def polynomial_decay(learning_rate,
     else:
         decay_steps_var = tensor.fill_constant(
             shape=[1], dtype='float32', value=float(decay_steps))
-        global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
+        global_step = nn.elementwise_min(x=global_step, y=decay_steps_var)

     decayed_lr = (learning_rate - end_learning_rate) * \
                  ((1 - global_step / decay_steps) ** power) + end_learning_rate
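
Only the import path changes here: elementwise_min now comes from nn rather than the generated ops module, so these schedulers keep working after the removal from ops. For reference, the Noam schedule the first hunk computes, restated in plain numpy (illustration only):

    def noam_lr(global_step, d_model, warmup_steps):
        a = global_step ** -0.5
        b = (warmup_steps ** -1.5) * global_step
        return (d_model ** -0.5) * min(a, b)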

python/paddle/fluid/layers/nn.py

Lines changed: 113 additions & 3 deletions
@@ -20,9 +20,9 @@
 import numpy as np
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
-from ..framework import Variable
+from ..framework import Variable, OpProtoHolder
 from ..param_attr import ParamAttr
-from .layer_function_generator import autodoc, templatedoc
+from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
 from .tensor import concat
 from . import utils
 from .. import unique_name
@@ -125,6 +125,14 @@
     'sequence_enumerate',
     'expand',
     'sequence_concat',
+    'scale',
+    'elementwise_add',
+    'elementwise_div',
+    'elementwise_sub',
+    'elementwise_mul',
+    'elementwise_max',
+    'elementwise_min',
+    'elementwise_pow',
 ]

@@ -3614,7 +3622,7 @@ def __check_input(x, y):
         attrs={
             'transpose_X': transpose_x,
             'transpose_Y': transpose_y,
-            'alpha': alpha,
+            'alpha': float(alpha),
         })
     return out

@@ -6453,3 +6461,105 @@ def expand(x, expand_times, name=None):
         outputs={'Out': out},
         attrs={'expand_times': expand_times})
     return out
+
+
+def _elementwise_op(helper):
+    op_type = helper.layer_type
+    x = helper.kwargs.get('x', None)
+    y = helper.kwargs.get('y', None)
+    assert x is not None, 'x cannot be None in {}'.format(op_type)
+    assert y is not None, 'y cannot be None in {}'.format(op_type)
+    axis = helper.kwargs.get('axis', -1)
+    use_mkldnn = helper.kwargs.get('use_mkldnn', False)
+    name = helper.kwargs.get('name', None)
+    if name is None:
+        out = helper.create_tmp_variable(dtype=x.dtype)
+    else:
+        out = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    helper.append_op(
+        type=op_type,
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': out},
+        attrs={'axis': axis,
+               'use_mkldnn': use_mkldnn})
+    return helper.append_activation(out)
+
+
+@templatedoc()
+def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x(${x_type}): ${x_comment}
+        scale(${scale_type}): ${scale_comment}
+        bias(${bias_type}): ${bias_comment}
+        bias_after_scale(${bias_after_scale_type}): ${bias_after_scale_comment}
+        act(basestring|None): Activation applied to the output.
+        name(basestring|None): Name of the output.
+
+    Returns:
+        out(${out_type}): ${out_comment}
+    """
+
+    helper = LayerHelper('scale', **locals())
+    if name is None:
+        out = helper.create_tmp_variable(dtype=x.dtype)
+    else:
+        out = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    helper.append_op(
+        type='scale',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={
+            'scale': float(scale),
+            'bias': float(bias),
+            'bias_after_scale': bias_after_scale
+        })
+    return helper.append_activation(out)
+
+
+def elementwise_add(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_add', **locals()))
+
+
+def elementwise_div(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_div', **locals()))
+
+
+def elementwise_sub(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_sub', **locals()))
+
+
+def elementwise_mul(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_mul', **locals()))
+
+
+def elementwise_max(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_max', **locals()))
+
+
+def elementwise_min(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_min', **locals()))
+
+
+def elementwise_pow(x, y, axis=-1, use_mkldnn=False, act=None, name=None):
+    return _elementwise_op(LayerHelper('elementwise_pow', **locals()))
+
+
+for func in [
+        elementwise_add, elementwise_div, elementwise_sub, elementwise_mul,
+        elementwise_max, elementwise_min, elementwise_pow
+]:
+    op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
+    func.__doc__ = _generate_doc_string_(
+        op_proto,
+        additional_args_lines=[
+            "act (basestring|None): Activation applied to the output.",
+            "name (basestring|None): Name of the output."
+        ])
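
Since the seven wrappers share one body, their docstrings are stamped out from the C++ OpProto at import time, with the two extra Args lines covering the Python-only parameters. A quick way to inspect the result (illustrative):

    import paddle.fluid as fluid

    # docstring assembled from the OpProto plus the two added lines
    print(fluid.layers.elementwise_add.__doc__)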

python/paddle/fluid/layers/ops.py

Lines changed: 5 additions & 8 deletions
@@ -37,15 +37,7 @@
3737
__all__ = [
3838
'mean',
3939
'mul',
40-
'scale',
4140
'sigmoid_cross_entropy_with_logits',
42-
'elementwise_add',
43-
'elementwise_div',
44-
'elementwise_sub',
45-
'elementwise_mul',
46-
'elementwise_max',
47-
'elementwise_min',
48-
'elementwise_pow',
4941
'clip',
5042
'clip_by_norm',
5143
'logical_and',
@@ -66,6 +58,11 @@
6658
for _OP in set(__all__):
6759
globals()[_OP] = generate_layer_fn(_OP)
6860

61+
# It is a hot fix in some unittest using:
62+
# fluid.layers.scale(x=x, scale=10.0, out=out_var)
63+
# e.g.: test_program_code.py, test_dist_train.py
64+
globals()['_scale'] = generate_layer_fn('scale')
65+
6966
__all__ += __activations_noattr__
7067

7168
for _OP in set(__activations_noattr__):
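
The explicit nn.scale signature has no out parameter, so unittests that pass out= would break; the generated _scale keeps the old kwargs-forwarding behavior alive for them, as the two test diffs below show:

    import paddle.fluid.layers.ops as ops

    # kwargs-style call that the explicit nn.scale no longer accepts;
    # x and out_var are assumed to be existing Variables
    ops._scale(x=x, scale=10.0, out=out_var)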

python/paddle/fluid/tests/unittests/test_dist_train.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 from paddle.fluid.layers.io import ListenAndServ
 from paddle.fluid.layers.io import Recv
 from paddle.fluid.layers.io import Send
+import paddle.fluid.layers.ops as ops

 from paddle.fluid import core

@@ -89,7 +90,7 @@ def init_serv(self, place):
             name="X",
             append_batch_size=False)
         fluid.initializer.Constant(value=1.0)(x, main.global_block())
-        layers.scale(x=x, scale=10.0, out=out_var)
+        ops._scale(x=x, scale=10.0, out=out_var)

         self.server_exe = fluid.Executor(place)
         self.server_exe.run(main)

python/paddle/fluid/tests/unittests/test_program_code.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from paddle.fluid.layers.io import ListenAndServ
 from paddle.fluid.layers.io import Recv
 from paddle.fluid.layers.io import Send
+import paddle.fluid.layers.ops as ops

 from paddle.fluid.transpiler.details import program_to_code

@@ -52,7 +53,7 @@ def init_serv(self, place):
             name="X",
             append_batch_size=False)
         fluid.initializer.Constant(value=1.0)(x, main.global_block())
-        layers.scale(x=x, scale=10.0, out=out_var)
+        ops._scale(x=x, scale=10.0, out=out_var)

         program_to_code(main)
