Skip to content

Commit 2b841ec

Browse files
authored
Merge pull request #421 from emailweixu/scaling_projection
Add ScalingProjection
2 parents 0ba0f02 + a6ad9a1 commit 2b841ec

File tree

12 files changed

+290
-41
lines changed

12 files changed

+290
-41
lines changed

doc/ui/api/trainer_config_helpers/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ embedding_layer
191191
:members: embedding_layer
192192
:noindex:
193193

194+
scaling_projection
195+
-----------------
196+
.. automodule:: paddle.trainer_config_helpers.layers
197+
:members: scaling_projection
198+
:noindex:
199+
194200
dotmul_projection
195201
-----------------
196202
.. automodule:: paddle.trainer_config_helpers.layers

paddle/gserver/layers/CostLayer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ class SumCostLayer : public Layer {
605605
int batchSize = input->getHeight();
606606
int size = 1;
607607
resizeOutput(batchSize, size);
608-
output_.value->sumRows(*input);
608+
output_.value->sumRows(*input, /* scaleSum= */1, /* scaleDest= */0);
609609
}
610610

611611
virtual void backward(const UpdateCallback& callback = nullptr) {

paddle/gserver/layers/FullMatrixProjection.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) {
5252
}
5353

5454
hl_set_sync_flag(syncFlag);
55-
parameter_->incUpdate(callback);
55+
if (weight_->getWGrad()) {
56+
parameter_->incUpdate(callback);
57+
}
5658
}
5759

5860
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "Projection.h"

namespace paddle {

/**
 * ScalingProjection scales its whole input by a single learned scalar:
 *   out += w * in
 * where w is the (1 x 1) parameter owned by this projection.
 */
class ScalingProjection : public Projection {
public:
  ScalingProjection(const ProjectionConfig& config,
                    const ParameterPtr& parameter, bool useGpu)
      : Projection(config, parameter, useGpu) {
    // The projection owns exactly one scalar parameter.
    CHECK_EQ(parameter->getSize(), 1UL);
    weight_.reset(new Weight(1, 1, parameter));
  }

  void forward() override {
    CHECK(in_->value);
    // Accumulate w * in into the output buffer.
    out_->value->add(*in_->value, weight_->getW()->getElement(0, 0));
  }

  void backward(const UpdateCallback& callback) override {
    // Both gradient paths below read the output gradient.
    CHECK(out_->grad);
    if (weight_->getWGrad()) {
      // dL/dw = sum_{i,j} in_{ij} * outGrad_{ij}: first reduce each row to a
      // per-sample sum, then reduce the column of per-sample sums into the
      // scalar gradient (accumulating with scaleDest=1).
      auto sum = Matrix::create(in_->value->getHeight(), 1, false, useGpu_);
      sum->sumOfProducts(*in_->value, *out_->grad,
                         /* scaleSum= */1, /* scaleDest= */0);
      weight_->getWGrad()->sumCols(*sum,
                                   /* scaleSum= */1, /* scaleDest= */1);
      parameter_->incUpdate(callback);
    }
    if (in_->grad) {
      // dL/din = w * outGrad, accumulated into the input gradient.
      in_->grad->add(*out_->grad, weight_->getW()->getElement(0, 0));
    }
  }

protected:
  /// The single (1 x 1) scaling weight.
  std::unique_ptr<Weight> weight_;
};

REGISTER_PROJECTION(scaling, ScalingProjection);

}  // namespace paddle

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ TEST(Projection, identity) {
135135
}
136136
}
137137

138+
// Gradient check for the scaling projection: a 10 -> 10 projection with a
// single scalar parameter, exercised on the CPU path only (the loop in the
// original iterated over exactly {false}).
TEST(Projection, scaling) {
  ProjectionConfig conf;
  conf.set_type("scaling");
  conf.set_input_size(10);
  conf.set_output_size(10);
  const bool useGpu = false;
  testProjectionGrad(conf, INPUT_DATA, /* parameterSize */ 1,
                     /* batchSize */ 100, useGpu);
}
148+
138149
#ifndef PADDLE_ONLY_CPU
139150
TEST(Projection, conv) {
140151
const int NUM_FILTERS = 16;

paddle/math/BaseMatrix.cu

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,8 @@ int BaseMatrixT<real>::applyRow(Agg agg, BaseMatrixT& b) {
14511451
MatrixOffset offset(0, 0, 0, 0, 0, 0);
14521452
int numRows = b.height_;
14531453
int numCols = b.width_;
1454+
CHECK_EQ(height_, numRows);
1455+
CHECK_EQ(width_, 1UL);
14541456
aggregate(agg, base::unary::identity(), base::binary::second(), b, numRows,
14551457
numCols, offset, false_type(), true_type() /*aAsColVector*/);
14561458

@@ -1463,18 +1465,69 @@ int BaseMatrixT<real>::applyRow(Agg agg, Saver sv, BaseMatrixT& b) {
14631465
MatrixOffset offset(0, 0, 0, 0, 0, 0);
14641466
int numRows = b.height_;
14651467
int numCols = b.width_;
1468+
CHECK_EQ(height_, numRows);
1469+
CHECK_EQ(width_, 1UL);
14661470
aggregate(agg, base::unary::identity(), sv, b, numRows, numCols, offset,
14671471
false_type(), true_type() /*aAsColVector*/);
14681472

14691473
return 0;
14701474
}
14711475

1476+
template<>
1477+
template <class Agg>
1478+
int BaseMatrixT<real>::applyRow(
1479+
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
1480+
if (scaleDest != 0) {
1481+
applyRow(agg, base::binary::add2(scaleDest, scaleAgg), b);
1482+
} else {
1483+
applyRow(agg, base::binary::second(), b);
1484+
if (scaleAgg != 1) {
1485+
mulScalar(scaleAgg);
1486+
}
1487+
}
1488+
return 0;
1489+
}
1490+
1491+
template<>
1492+
template <class Agg, class Op, class Saver>
1493+
int BaseMatrixT<real>::applyRow(Agg agg, Op op, Saver sv,
1494+
BaseMatrixT& b, BaseMatrixT& c) {
1495+
MatrixOffset offset(0, 0, 0, 0, 0, 0);
1496+
int numRows = b.height_;
1497+
int numCols = b.width_;
1498+
CHECK_EQ(height_, numRows);
1499+
CHECK_EQ(width_, 1UL);
1500+
CHECK_EQ(c.height_, numRows);
1501+
CHECK_EQ(c.width_, numCols);
1502+
aggregate(agg, op, sv,
1503+
b, c, numRows, numCols, offset,
1504+
false_type(), true_type() /*aAsColVector*/);
1505+
return 0;
1506+
}
1507+
1508+
template<>
1509+
template <class Agg, class Op>
1510+
int BaseMatrixT<real>::applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
1511+
BaseMatrixT& b, BaseMatrixT& c) {
1512+
if (scaleDest != 0) {
1513+
applyRow(agg, op, base::binary::add2(scaleDest, scaleAgg), b, c);
1514+
} else {
1515+
applyRow(agg, op, base::binary::second(), b, c);
1516+
if (scaleAgg != 1) {
1517+
mulScalar(scaleAgg);
1518+
}
1519+
}
1520+
return 0;
1521+
}
1522+
14721523
template<>
14731524
template <class Agg>
14741525
int BaseMatrixT<real>::applyCol(Agg agg, BaseMatrixT& b) {
14751526
MatrixOffset offset(0, 0, 0, 0, 0, 0);
14761527
int numRows = b.height_;
14771528
int numCols = b.width_;
1529+
CHECK_EQ(width_, numCols);
1530+
CHECK_EQ(height_, 1UL);
14781531
aggregate(agg, base::unary::identity(), base::binary::second(), b, numRows,
14791532
numCols, offset, true_type() /*aAsRowVector*/, false_type());
14801533

@@ -1487,15 +1540,32 @@ int BaseMatrixT<real>::applyCol(Agg agg, Saver sv, BaseMatrixT& b) {
14871540
MatrixOffset offset(0, 0, 0, 0, 0, 0);
14881541
int numRows = b.height_;
14891542
int numCols = b.width_;
1543+
CHECK_EQ(width_, numCols);
1544+
CHECK_EQ(height_, 1UL);
14901545
aggregate(agg, base::unary::identity(), sv, b, numRows, numCols, offset,
14911546
true_type() /*aAsRowVector*/, false_type());
14921547

14931548
return 0;
14941549
}
14951550

14961551
template<>
1497-
void BaseMatrixT<real>::sumRows(BaseMatrixT& b) {
1498-
applyRow(aggregate::sum(), b);
1552+
template <class Agg>
1553+
int BaseMatrixT<real>::applyCol(
1554+
Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b) {
1555+
if (scaleDest != 0) {
1556+
applyCol(agg, base::binary::add2(scaleDest, scaleAgg), b);
1557+
} else {
1558+
applyCol(agg, base::binary::second(), b);
1559+
if (scaleAgg != 1) {
1560+
mulScalar(scaleAgg);
1561+
}
1562+
}
1563+
return 0;
1564+
}
1565+
1566+
template<>
1567+
void BaseMatrixT<real>::sumRows(BaseMatrixT& b, real scaleSum, real scaleDest) {
1568+
applyRow(aggregate::sum(), scaleDest, scaleSum, b);
14991569
}
15001570

15011571
template<>
@@ -1524,18 +1594,22 @@ void BaseMatrixT<real>::minCols(BaseMatrixT& b) {
15241594
}
15251595

15261596
template<>
1527-
void BaseMatrixT<real>::sumCols(BaseMatrixT& b, real scale) {
1528-
applyCol(aggregate::sum(), base::binary::add2(1.0, scale), b);
1597+
void BaseMatrixT<real>::sumCols(BaseMatrixT& b, real scaleSum, real scaleDest) {
1598+
applyCol(aggregate::sum(), scaleDest, scaleSum, b);
15291599
}
15301600

15311601
template<>
1532-
void BaseMatrixT<real>::sumOfSquares(BaseMatrixT& b, BaseMatrixT& c) {
1533-
int numRows = b.height_;
1534-
int numCols = b.width_;
1535-
MatrixOffset offset(0, 0, 0, 0, 0, 0);
1536-
aggregate(aggregate::sum(), base::binary::squaredDiff(), base::binary::add(),
1537-
b, c, numRows, numCols, offset, false_type(),
1538-
true_type() /*aAsColVector*/);
1602+
void BaseMatrixT<real>::sumOfSquaredDiffs(
1603+
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
1604+
applyRow(aggregate::sum(), base::binary::squaredDiff(),
1605+
scaleDest, scaleSum, b, c);
1606+
}
1607+
1608+
template<>
1609+
void BaseMatrixT<real>::sumOfProducts(
1610+
BaseMatrixT& b, BaseMatrixT& c, real scaleSum, real scaleDest) {
1611+
applyRow(aggregate::sum(), base::binary::mul(),
1612+
scaleDest, scaleSum, b, c);
15391613
}
15401614

15411615
template class BaseMatrixT<real>;

paddle/math/BaseMatrix.h

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,23 @@ class BaseMatrixT {
305305
template <class Agg>
306306
int applyRow(Agg agg, BaseMatrixT& b);
307307

308+
/**
309+
* an aggregate expression applied to each row of matrix b.
310+
*
311+
* @code
312+
* for each row i & 0 <= j < b.width_, do:
313+
* dst = agg(op(b[i*ldb + j], c[i*ldc + j]))
314+
* this[i] = sv(this[i], dst)
315+
* @endcode
316+
*/
317+
template <class Agg, class Op, class Saver>
318+
int applyRow(Agg agg, Op op, Saver sv, BaseMatrixT& b, BaseMatrixT& c);
319+
320+
// Same as the above with the special handling of sv=add2(scaleDest, scaleAgg)
321+
template <class Agg, class Op>
322+
int applyRow(Agg agg, Op op, real scaleDest, real scaleAgg,
323+
BaseMatrixT& b, BaseMatrixT& c);
324+
308325
/**
309326
* an aggregate expression applied to each row of matrix b.
310327
*
@@ -317,6 +334,10 @@ class BaseMatrixT {
317334
template <class Agg, class Saver>
318335
int applyRow(Agg agg, Saver sv, BaseMatrixT& b);
319336

337+
// Same as the above with the special handling of sv=add2(scaleDest, scaleAgg)
338+
template <class Agg>
339+
int applyRow(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
340+
320341
/**
321342
* a aggregate expression that apply each column of matrix b.
322343
*
@@ -340,6 +361,10 @@ class BaseMatrixT {
340361
template <class Agg, class Saver>
341362
int applyCol(Agg agg, Saver sv, BaseMatrixT& b);
342363

364+
// Same as the above with the special handling of sv=add2(scaleDest, scaleAgg)
365+
template <class Agg>
366+
int applyCol(Agg agg, real scaleDest, real scaleAgg, BaseMatrixT& b);
367+
343368
bool useGpu() const { return useGpu_; }
344369

345370
const T* rowBuf(size_t row) const { return data_ + width_ * row; }
@@ -920,7 +945,9 @@ class BaseMatrixT {
920945
void addRowScale(size_t cCol, BaseMatrixT& b, BaseMatrixT& c);
921946

922947
/// calculate the sum of each row of the matrix b.
923-
void sumRows(BaseMatrixT& b);
948+
/// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij}
949+
void sumRows(BaseMatrixT& b, T scaleSum, T scaleDest);
950+
924951
/// calculate the maximum value of each row of the matrix b.
925952
void maxRows(BaseMatrixT& b);
926953
/// calculate the minimum value of each row of the matrix b.
@@ -932,10 +959,18 @@ class BaseMatrixT {
932959
void maxCols(BaseMatrixT& b);
933960
/// calculate the minimum value of each column of the matrix b.
934961
void minCols(BaseMatrixT& b);
935-
void sumCols(BaseMatrixT& b, T scale);
936962

937-
/// calculate the sum of each row of (b - c)^2.
938-
void sumOfSquares(BaseMatrixT& b, BaseMatrixT& c);
963+
/// calculate the sum of each column of the matrix b.
964+
/// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji}
965+
void sumCols(BaseMatrixT& b, T scaleSum, T scaleDest);
966+
967+
/// this_i = scaleDest * this_i + scaleSum * \sum_j (b_{ij} - c_{ij})^2
968+
void sumOfSquaredDiffs(BaseMatrixT& b, BaseMatrixT& c,
969+
T scaleSum, T scaleDest);
970+
971+
/// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij}
972+
void sumOfProducts(BaseMatrixT& b, BaseMatrixT& c,
973+
T scaleSum, T scaleDest);
939974

940975
/**
941976
* @code

0 commit comments

Comments
 (0)