Skip to content

Commit 25a0193

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into iou_sim
2 parents a05d25c + c648244 commit 25a0193

File tree

11 files changed

+238
-44
lines changed

11 files changed

+238
-44
lines changed

doc/api/v2/fluid/layers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ dynamic_lstm
1818
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
1919
:noindex:
2020

21+
dynamic_gru
22+
-----------
23+
.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
24+
:noindex:
25+
2126
data
2227
----
2328
.. autofunction:: paddle.v2.fluid.layers.data

doc/getstarted/build_and_install/docker_install_cn.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
.. code-block:: bash
2727
28-
docker pull docker.paddlepaddle.org/paddle
28+
docker pull docker.paddlepaddlehub.com/paddle
2929
3030
下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像:
3131

3232
.. code-block:: bash
3333
3434
docker pull paddlepaddle/paddle:latest-gpu
35-
docker pull docker.paddlepaddle.org/paddle:latest-gpu
35+
docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
3636
3737
选择下载使用不同的BLAS库的Docker镜像:
3838

@@ -49,7 +49,7 @@
4949
5050
docker pull paddlepaddle/paddle:[tag]
5151
# 比如:
52-
docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
52+
docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
5353
5454
.. _docker_run:
5555

doc/getstarted/build_and_install/docker_install_en.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ For users in China, we provide a faster mirror:
2626

2727
.. code-block:: bash
2828
29-
docker pull docker.paddlepaddle.org/paddle
29+
docker pull docker.paddlepaddlehub.com/paddle
3030
3131
Download GPU version (cuda8.0_cudnn5_avx_mkl) images:
3232

3333
.. code-block:: bash
3434
3535
docker pull paddlepaddle/paddle:latest-gpu
36-
docker pull docker.paddlepaddle.org/paddle:latest-gpu
36+
docker pull docker.paddlepaddlehub.com/paddle:latest-gpu
3737
3838
Choose between different BLAS version:
3939

@@ -53,7 +53,7 @@ and run:
5353
5454
docker pull paddlepaddle/paddle:[tag]
5555
# i.e.
56-
docker pull docker.paddlepaddle.org/paddle:0.10.0-gpu
56+
docker pull docker.paddlepaddlehub.com/paddle:0.11.0-gpu
5757
5858
.. _docker_run:
5959

paddle/framework/variable_test.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
/*
16-
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
17-
Licensed under the Apache License, Version 2.0 (the "License");
18-
you may not use this file except in compliance with the License.
19-
You may obtain a copy of the License at
20-
http://www.apache.org/licenses/LICENSE-2.0
21-
Unless required by applicable law or agreed to in writing, software
22-
distributed under the License is distributed on an "AS IS" BASIS,
23-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24-
See the License for the specific language governing permissions and
25-
limitations under the License.
26-
*/
27-
2815
#include <memory>
2916
#include <string>
3017

paddle/operators/bipartite_match_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ namespace operators {
2121
using Tensor = framework::Tensor;
2222
using LoDTensor = framework::LoDTensor;
2323

24-
constexpr char kEPS = 1e-6;
25-
2624
class BipartiteMatchOp : public framework::OperatorWithKernel {
2725
public:
2826
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
4644
// The match_dist must be initialized to 0 at first.
4745
void BipartiteMatch(const Tensor& dist, int* match_indices,
4846
T* match_dist) const {
47+
constexpr T kEPS = static_cast<T>(1e-6);
4948
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
5049
int64_t row = dist.dims()[0];
5150
int64_t col = dist.dims()[1];

paddle/operators/nce_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
124124
"This attribute only be used in unitest. Classes "
125125
"in this list wiil be used as negative classes "
126126
"for every samples. Under normal conditions, "
127-
"user should avoid setting this attribute.");
127+
"user should avoid setting this attribute.")
128+
.SetDefault({});
128129
AddComment(R"DOC(
129130
Compute and return the noise-contrastive estimation training loss.
130131
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).

paddle/operators/nce_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
197197
// get d_x
198198
auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
199199
if (d_x != nullptr) {
200-
d_x->mutable_data<T>(context.GetPlace());
200+
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
201+
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
201202
auto d_x_matrix = EigenMatrix<T>::From(*d_x);
202203
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
203204
for (int64_t i = 0; i < sample_labels->numel(); ++i) {

python/paddle/v2/dataset/wmt16.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False):
305305

306306
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
307307
"wmt16/%s_%d.dict" % (lang, dict_size))
308-
assert (os.path.exists(dict_path), "Word dictionary does not exist. "
309-
"Please invoke paddle.dataset.wmt16.train/test/validation "
310-
"first to build the dictionary.")
308+
assert os.path.exists(dict_path), "Word dictionary does not exist. "
309+
"Please invoke paddle.dataset.wmt16.train/test/validation first "
310+
"to build the dictionary."
311311
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
312312
return __load_dict(tar_file, dict_size, lang, reverse)
313313

python/paddle/v2/fluid/layers/nn.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
from ..initializer import Normal, Constant
2020
from ..framework import Variable
2121
from ..param_attr import ParamAttr
22+
from layer_function_generator import autodoc
2223
from tensor import concat
2324

2425
__all__ = [
2526
'fc',
2627
'embedding',
2728
'dynamic_lstm',
29+
'dynamic_gru',
2830
'gru_unit',
2931
'linear_chain_crf',
3032
'crf_decoding',
@@ -57,6 +59,7 @@
5759
'warpctc',
5860
'sequence_reshape',
5961
'transpose',
62+
'nce',
6063
]
6164

6265

@@ -366,6 +369,113 @@ def dynamic_lstm(input,
366369
return hidden, cell
367370

368371

372+
def dynamic_gru(input,
373+
size,
374+
param_attr=None,
375+
bias_attr=None,
376+
is_reverse=False,
377+
gate_activation='sigmoid',
378+
candidate_activation='tanh',
379+
h_0=None):
380+
"""
381+
**Dynamic GRU Layer**
382+
383+
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
384+
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_
385+
386+
The formula is as follows:
387+
388+
.. math::
389+
390+
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
391+
392+
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
393+
394+
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
395+
396+
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
397+
398+
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
399+
is the update gate and reset gate activation function and :math:`sigmoid`
400+
is usually used for it. :math:`act_c` is the activation function for
401+
candidate hidden state and :math:`tanh` is usually used for it.
402+
403+
Note that these :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` operations on
404+
the input :math:`x_{t}` are NOT included in this operator. Users can choose
405+
to use a fully-connected layer before the GRU layer.
406+
407+
Args:
408+
input(Variable): The input of dynamic_gru layer, which supports
409+
variable-time length input sequence. The underlying tensor in this
410+
Variable is a matrix with shape :math:`(T \\times 3D)`, where
411+
:math:`T` is the total time steps in this mini-batch, :math:`D`
412+
is the hidden size.
413+
size(int): The dimension of the gru cell.
414+
param_attr(ParamAttr|None): The parameter attribute for the learnable
415+
hidden-hidden weight matrix. Note:
416+
417+
- The shape of the weight matrix is :math:`(D \\times 3D)`, where
418+
:math:`D` is the hidden size.
419+
- All elements in the weight matrix can be divided into two parts.
420+
The first part are weights of the update gate and reset gate with
421+
shape :math:`(D \\times 2D)`, and the second part are weights for
422+
candidate hidden state with shape :math:`(D \\times D)`.
423+
bias_attr(ParamAttr): The parameter attribute for the learnable
424+
hidden-hidden bias.
425+
is_reverse(bool): Whether to compute reversed GRU, default
426+
:attr:`False`.
427+
gate_activation(str): The activation for update gate and reset gate.
428+
Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
429+
candidate_activation(str): The activation for candidate hidden state.
430+
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
431+
432+
Returns:
433+
Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
434+
is the same as the input.
435+
436+
Examples:
437+
.. code-block:: python
438+
439+
hidden_dim = 512
440+
x = fluid.layers.fc(input=data, size=hidden_dim * 3)
441+
hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim)
442+
"""
443+
444+
helper = LayerHelper('gru', **locals())
445+
dtype = helper.input_dtype()
446+
447+
weight = helper.create_parameter(
448+
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
449+
bias = helper.create_parameter(
450+
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
451+
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
452+
if h_0 != None:
453+
assert h_0.shape == (
454+
size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
455+
inputs['h0'] = h_0
456+
457+
hidden = helper.create_tmp_variable(dtype)
458+
batch_gate = helper.create_tmp_variable(dtype)
459+
batch_reset_hidden_prev = helper.create_tmp_variable(dtype)
460+
batch_hidden = helper.create_tmp_variable(dtype)
461+
462+
helper.append_op(
463+
type='gru',
464+
inputs=inputs,
465+
outputs={
466+
'Hidden': hidden,
467+
'BatchGate': batch_gate,
468+
'BatchResetHiddenPrev': batch_reset_hidden_prev,
469+
'BatchHidden': batch_hidden
470+
},
471+
attrs={
472+
'is_reverse': is_reverse,
473+
'gate_activation': gate_activation,
474+
'activation': candidate_activation
475+
})
476+
return hidden
477+
478+
369479
def gru_unit(input,
370480
hidden,
371481
size,
@@ -2190,6 +2300,61 @@ def sequence_reshape(input, new_dim):
21902300
return out
21912301

21922302

2303+
@autodoc()
2304+
def nce(input,
2305+
label,
2306+
num_total_classes,
2307+
sample_weight=None,
2308+
param_attr=None,
2309+
bias_attr=None,
2310+
num_neg_samples=None):
2311+
helper = LayerHelper('nce', **locals())
2312+
assert isinstance(input, Variable)
2313+
dim = input.shape[1]
2314+
assert isinstance(label, Variable)
2315+
num_true_class = label.shape[1]
2316+
w = helper.create_parameter(
2317+
attr=helper.param_attr,
2318+
shape=[num_total_classes, dim],
2319+
is_bias=False,
2320+
dtype=input.dtype)
2321+
b = helper.create_parameter(
2322+
attr=helper.bias_attr,
2323+
shape=[num_total_classes, 1],
2324+
is_bias=True,
2325+
dtype=input.dtype)
2326+
cost = helper.create_tmp_variable(dtype=input.dtype)
2327+
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
2328+
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
2329+
2330+
if num_neg_samples is None:
2331+
num_neg_samples = 10
2332+
else:
2333+
num_neg_samples = int(num_neg_samples)
2334+
2335+
attrs = {
2336+
'num_total_classes': int(num_total_classes),
2337+
'num_neg_samples': num_neg_samples
2338+
}
2339+
2340+
helper.append_op(
2341+
type='nce',
2342+
inputs={
2343+
'Input': input,
2344+
'Label': label,
2345+
'Weight': w,
2346+
'Bias': b,
2347+
'SampleWeight': sample_weight if sample_weight is not None else []
2348+
},
2349+
outputs={
2350+
'Cost': cost,
2351+
'SampleLogits': sample_logits,
2352+
'SampleLabels': sample_labels
2353+
},
2354+
attrs=attrs)
2355+
return cost / (num_neg_samples + 1)
2356+
2357+
21932358
def transpose(x, perm, name=None):
21942359
"""
21952360
**transpose Layer**

0 commit comments

Comments
 (0)