Skip to content

Commit 4772b78

Browse files
committed
add config_helper.
1 parent dfc5d1f commit 4772b78

File tree

8 files changed

+142
-31
lines changed

8 files changed

+142
-31
lines changed

doc/api/v2/config/layer.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,11 @@ cos_sim
372372
.. autoclass:: paddle.v2.layer.cos_sim
373373
:noindex:
374374

375+
l2_distance
376+
-----------
377+
.. autoclass:: paddle.v2.layer.l2_distance
378+
:noindex:
379+
375380
trans
376381
-----
377382
.. autoclass:: paddle.v2.layer.trans

paddle/gserver/layers/L2DistanceLayer.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ bool L2DistanceLayer::init(const LayerMap& layerMap,
2525
/* Initialize the basic parent class */
2626
Layer::init(layerMap, parameterMap);
2727

28-
CHECK_EQ(inputLayers_.size(), 2UL) << "The L2 distance layer accepts two and "
28+
CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
2929
<< "only two inputs.";
30-
CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2 distance"
30+
CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
3131
<< "is fixed to be 1.";
3232

3333
return true;
@@ -41,9 +41,9 @@ void L2DistanceLayer::forward(PassType passType) {
4141

4242
CHECK(inV1 && inV2);
4343
CHECK_EQ(inV1->getHeight(), inV2->getHeight())
44-
<< "The height of two inputs to this layer must be the same.";
44+
<< "The height of two inputs of this layer must be the same.";
4545
CHECK_EQ(inV1->getWidth(), inV2->getWidth())
46-
<< "The width of two inputs to this layer must be the same.";
46+
<< "The width of two inputs of this layer must be the same.";
4747

4848
int batchSize = inV1->getHeight();
4949
int output_dim = getSize();
@@ -66,22 +66,21 @@ void L2DistanceLayer::forward(PassType passType) {
6666
void L2DistanceLayer::backward(const UpdateCallback& callback) {
6767
const auto outG = getOutputGrad();
6868
const auto outV = getOutputValue();
69-
const auto inV1 = getInputValue(0);
70-
const auto inV2 = getInputValue(1);
69+
CHECK(outG && outV);
70+
7171
auto inGrad1 = getInputGrad(0);
7272
auto inGrad2 = getInputGrad(1);
73-
CHECK(outG && outV && inV1 && inV2 && inGrad1 && inGrad2);
7473

7574
{
7675
REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
7776

78-
outV->scalarDiv(*outV, 1.);
79-
outV->dotMul(*outG, *outV);
80-
81-
if (inGrad1) {
82-
inGrad1->addRowScale(0, *inputSub_, *outV);
77+
if (inGrad1 || inGrad2) {
78+
outV->scalarDiv(*outV, 1.);
79+
outV->dotMul(*outG, *outV);
8380
}
8481

82+
if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);
83+
8584
if (inGrad2) {
8685
inputSub_->mulScalar(-1.);
8786
inGrad2->addRowScale(0, *inputSub_, *outV);

paddle/gserver/layers/L2DistanceLayer.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@ limitations under the License. */
1616

1717
#include "Layer.h"
1818
#include "paddle/math/Matrix.h"
19-
#include "paddle/utils/ThreadLocal.h"
2019

2120
namespace paddle {
2221

2322
/**
24-
* @brief A layer for calculating l2 distance between the two input vectors.
23+
* @brief The layer calculates the l2 distance between two input vectors.
2524
* \f[
2625
* f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)}
2726
* \f]
@@ -30,13 +29,12 @@ namespace paddle {
3029
* - Input2: A vector (batchSize * dataDim)
3130
* - Output: A vector (batchSize * 1)
3231
*
33-
* The config file api is l2_distance.
32+
* The configuration api is: l2_distance_layer.
3433
*/
3534

3635
class L2DistanceLayer : public Layer {
3736
public:
3837
explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
39-
4038
~L2DistanceLayer() {}
4139

4240
bool init(const LayerMap& layerMap,
@@ -46,7 +44,8 @@ class L2DistanceLayer : public Layer {
4644
void backward(const UpdateCallback& callback = nullptr) override;
4745

4846
private:
49-
// Store result of subtracting Input2 from Input1.
47+
// Store the result of subtracting Input2 from Input1 in forward computation,
48+
// which will be reused in backward computation.
5049
MatrixPtr inputSub_;
5150
};
5251

python/paddle/trainer/config_parser.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,6 +3330,18 @@ def __init__(self, name, inputs, **xargs):
33303330
self.set_layer_size(input_layer.size)
33313331

33323332

3333+
@config_layer('cos')
3334+
class CosSimLayer(LayerBase):
3335+
def __init__(self, name, inputs, cos_scale=1, device=None):
3336+
super(CosSimLayer, self).__init__(
3337+
name, 'cos', 1, inputs=inputs, device=device)
3338+
config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
3339+
config_assert(
3340+
self.get_input_layer(0).size == self.get_input_layer(1).size,
3341+
'inputs of CosSimLayer must have same dim')
3342+
self.config.cos_scale = cos_scale
3343+
3344+
33333345
@config_layer('cos_vm')
33343346
class CosSimVecMatLayer(LayerBase):
33353347
def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
@@ -3343,6 +3355,20 @@ def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
33433355
'Wrong input size for CosSimVecMatLayer')
33443356

33453357

3358+
@config_layer('l2_distance')
3359+
class L2DistanceLayer(LayerBase):
3360+
def __init__(self, name, inputs, device=None):
3361+
super(L2DistanceLayer, self).__init__(
3362+
name, 'l2_distance', 1, inputs=inputs, device=device)
3363+
config_assert(
3364+
len(self.inputs) == 2, ('The L2DistanceLayer must have '
3365+
'and only have 2 inputs.'))
3366+
config_assert(
3367+
self.get_input_layer(0).size == self.get_input_layer(1).size,
3368+
('Two inputs of the L2DistanceLayer must have '
3369+
'the same dimensionality.'))
3370+
3371+
33463372
@config_layer('sampling_id')
33473373
class SamplingIdLayer(LayerBase):
33483374
def __init__(self, name, inputs, device=None):
@@ -3384,18 +3410,6 @@ def __init__(self,
33843410
self.create_bias_parameter(bias, self.config.size)
33853411

33863412

3387-
@config_layer('cos')
3388-
class CosSimLayer(LayerBase):
3389-
def __init__(self, name, inputs, cos_scale=1, device=None):
3390-
super(CosSimLayer, self).__init__(
3391-
name, 'cos', 1, inputs=inputs, device=device)
3392-
config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
3393-
config_assert(
3394-
self.get_input_layer(0).size == self.get_input_layer(1).size,
3395-
'inputs of CosSimLayer must have same dim')
3396-
self.config.cos_scale = cos_scale
3397-
3398-
33993413
@config_layer('tensor')
34003414
class TensorLayer(LayerBase):
34013415
def __init__(self, name, size, inputs, bias=True, **xargs):

python/paddle/trainer_config_helpers/layers.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
'last_seq',
5252
'first_seq',
5353
'cos_sim',
54+
'l2_distance_layer',
5455
'hsigmoid',
5556
'conv_projection',
5657
'square_error_cost',
@@ -167,6 +168,7 @@ class LayerType(object):
167168
COST = 'cost'
168169
COSINE_SIM_VEC = 'cos_vm'
169170
COSINE_SIM = 'cos'
171+
L2_DISTANCE = 'l2_distance'
170172
HSIGMOID = 'hsigmoid'
171173
CONV_LAYER = 'conv'
172174
CONVTRANS_LAYER = 'convt'
@@ -2332,6 +2334,51 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
23322334
return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b], size=size)
23332335

23342336

2337+
@wrap_name_default()
2338+
@layer_support()
2339+
def l2_distance_layer(x, y, name=None, layer_attr=None):
2340+
"""
2341+
This layer calculate and return the Euclidean distance between two input
2342+
vectors a and b. The equation is as follows:
2343+
2344+
.. math::
2345+
l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)}
2346+
2347+
The output size of this layer is fixed to be 1. Note that the above
2348+
computation is for one sample. Multiple samples are processed in one batch.
2349+
2350+
The example usage is:
2351+
2352+
.. code-block:: python
2353+
2354+
l2_sim = l2_distance(x=layer1, y=layer2)
2355+
2356+
:param name: The name of this layer. It is optional.
2357+
:type name: basestring
2358+
:param x: The first input x for this layer, whose output is a matrix with
2359+
dimensionality N x D. N is the sample number in a mini-batch.
2360+
D is the dimensionality of x's output.
2361+
:type x: LayerOutput
2362+
:param y: The second input y for this layer, whose output is a matrix with
2363+
dimensionality N x D. N is the sample number in a mini-batch.
2364+
D is the dimensionality of y's output.
2365+
:type y: LayerOutput
2366+
:param layer_attr: The extra layer attributes, for example, drop rate.
2367+
See ExtraLayerAttribute for more details.
2368+
:type layer_attr: ExtraLayerAttribute
2369+
:return: The returned LayerOutput object.
2370+
:rtype: LayerOutput
2371+
"""
2372+
2373+
assert isinstance(x, LayerOutput) and isinstance(x, LayerOutput)
2374+
Layer(
2375+
name=name,
2376+
type=LayerType.L2_DISTANCE,
2377+
inputs=[x.name, x.name],
2378+
**ExtraLayerAttribute.to_kwargs(layer_attr))
2379+
return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1)
2380+
2381+
23352382
@wrap_name_default()
23362383
@wrap_bias_attr_default(has_bias=True)
23372384
@wrap_param_attr_default()
@@ -3867,7 +3914,7 @@ def recurrent_layer(input,
38673914
:type input: LayerOutput
38683915
:param act: Activation type. TanhActivation is the default activation.
38693916
:type act: BaseActivation
3870-
:param bias_attr: The parameter attribute for bias. If this parameter is set to
3917+
:param bias_attr: The parameter attribute for bias. If this parameter is set to
38713918
False or an object whose type is not ParameterAttribute,
38723919
no bias is defined. If the parameter is set to True,
38733920
the bias is initialized to zero.

python/paddle/trainer_config_helpers/tests/configs/file_list.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
1010
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
1111
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
1212
test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
13-
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer)
13+
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
14+
test_scale_sub_region_layer test_l2_distance_layer)
1415

1516
export whole_configs=(test_split_datasource)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
type: "nn"
2+
layers {
3+
name: "x"
4+
type: "data"
5+
size: 128
6+
active_type: ""
7+
}
8+
layers {
9+
name: "y"
10+
type: "data"
11+
size: 128
12+
active_type: ""
13+
}
14+
layers {
15+
name: "__l2_distance_layer_0__"
16+
type: "l2_distance"
17+
size: 1
18+
active_type: ""
19+
inputs {
20+
input_layer_name: "x"
21+
}
22+
inputs {
23+
input_layer_name: "x"
24+
}
25+
}
26+
input_layer_names: "x"
27+
input_layer_names: "y"
28+
output_layer_names: "__l2_distance_layer_0__"
29+
sub_models {
30+
name: "root"
31+
layer_names: "x"
32+
layer_names: "y"
33+
layer_names: "__l2_distance_layer_0__"
34+
input_layer_names: "x"
35+
input_layer_names: "y"
36+
output_layer_names: "__l2_distance_layer_0__"
37+
is_recurrent_layer_group: false
38+
}
39+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from paddle.trainer_config_helpers import *
2+
3+
outputs(
4+
l2_distance_layer(
5+
x=data_layer(
6+
name='x', size=128), y=data_layer(
7+
name='y', size=128)))

0 commit comments

Comments
 (0)