Skip to content

Commit dfc5d1f

Browse files
committed
add the l2 distance layer.
1 parent e0e3a8a commit dfc5d1f

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "L2DistanceLayer.h"
16+
#include "paddle/utils/Logging.h"
17+
#include "paddle/utils/Stat.h"
18+
19+
namespace paddle {
20+
21+
REGISTER_LAYER(l2_distance, L2DistanceLayer);
22+
23+
bool L2DistanceLayer::init(const LayerMap& layerMap,
24+
const ParameterMap& parameterMap) {
25+
/* Initialize the basic parent class */
26+
Layer::init(layerMap, parameterMap);
27+
28+
CHECK_EQ(inputLayers_.size(), 2UL) << "The L2 distance layer accepts two and "
29+
<< "only two inputs.";
30+
CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2 distance"
31+
<< "is fixed to be 1.";
32+
33+
return true;
34+
}
35+
36+
void L2DistanceLayer::forward(PassType passType) {
37+
Layer::forward(passType);
38+
39+
const auto inV1 = getInputValue(0);
40+
const auto inV2 = getInputValue(1);
41+
42+
CHECK(inV1 && inV2);
43+
CHECK_EQ(inV1->getHeight(), inV2->getHeight())
44+
<< "The height of two inputs to this layer must be the same.";
45+
CHECK_EQ(inV1->getWidth(), inV2->getWidth())
46+
<< "The width of two inputs to this layer must be the same.";
47+
48+
int batchSize = inV1->getHeight();
49+
int output_dim = getSize();
50+
{
51+
REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
52+
reserveOutput(batchSize, output_dim);
53+
auto outV = getOutputValue();
54+
CHECK(outV) << "The output matrix should not be null.";
55+
56+
Matrix::resizeOrCreate(
57+
inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);
58+
59+
inputSub_->assign(*inV1);
60+
inputSub_->sub(*inV2);
61+
outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
62+
outV->sqrt2(*outV);
63+
}
64+
}
65+
66+
void L2DistanceLayer::backward(const UpdateCallback& callback) {
67+
const auto outG = getOutputGrad();
68+
const auto outV = getOutputValue();
69+
const auto inV1 = getInputValue(0);
70+
const auto inV2 = getInputValue(1);
71+
auto inGrad1 = getInputGrad(0);
72+
auto inGrad2 = getInputGrad(1);
73+
CHECK(outG && outV && inV1 && inV2 && inGrad1 && inGrad2);
74+
75+
{
76+
REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
77+
78+
outV->scalarDiv(*outV, 1.);
79+
outV->dotMul(*outG, *outV);
80+
81+
if (inGrad1) {
82+
inGrad1->addRowScale(0, *inputSub_, *outV);
83+
}
84+
85+
if (inGrad2) {
86+
inputSub_->mulScalar(-1.);
87+
inGrad2->addRowScale(0, *inputSub_, *outV);
88+
}
89+
}
90+
}
91+
92+
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "Layer.h"
18+
#include "paddle/math/Matrix.h"
19+
#include "paddle/utils/ThreadLocal.h"
20+
21+
namespace paddle {
22+
23+
/**
24+
* @brief A layer for calculating l2 distance between the two input vectors.
25+
* \f[
26+
* f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)}
27+
* \f]
28+
*
29+
* - Input1: A vector (batchSize * dataDim)
30+
* - Input2: A vector (batchSize * dataDim)
31+
* - Output: A vector (batchSize * 1)
32+
*
33+
* The config file api is l2_distance.
34+
*/
35+
36+
class L2DistanceLayer : public Layer {
37+
public:
38+
explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
39+
40+
~L2DistanceLayer() {}
41+
42+
bool init(const LayerMap& layerMap,
43+
const ParameterMap& parameterMap) override;
44+
45+
void forward(PassType passType) override;
46+
void backward(const UpdateCallback& callback = nullptr) override;
47+
48+
private:
49+
// Store result of subtracting Input2 from Input1.
50+
MatrixPtr inputSub_;
51+
};
52+
53+
} // namespace paddle

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ TEST(Layer, maxoutLayer) {
583583
testLayerGrad(config, "maxout", 10, false, useGpu);
584584
}
585585
}
586+
586587
void testFcLayer(string format, size_t nnz) {
587588
TestConfig config;
588589
config.biasSize = 1024;
@@ -2429,6 +2430,25 @@ TEST(Layer, ScaleSubRegionLayer) {
24292430
}
24302431
}
24312432

2433+
TEST(Layer, L2DistanceLayer) {
2434+
TestConfig config;
2435+
config.layerConfig.set_type("l2_distance");
2436+
config.layerConfig.set_size(1);
2437+
config.biasSize = 0;
2438+
2439+
const size_t input_dim = 27;
2440+
const size_t batch_size = 11;
2441+
2442+
config.inputDefs.push_back({INPUT_DATA, "layer_0", input_dim, 0});
2443+
config.inputDefs.push_back({INPUT_DATA, "layer_1", input_dim, 0});
2444+
config.layerConfig.add_inputs();
2445+
config.layerConfig.add_inputs();
2446+
2447+
for (auto useGpu : {false, true}) {
2448+
testLayerGrad(config, "l2_distance", batch_size, false, useGpu);
2449+
}
2450+
}
2451+
24322452
int main(int argc, char** argv) {
24332453
testing::InitGoogleTest(&argc, argv);
24342454
initMain(argc, argv);

0 commit comments

Comments
 (0)