Skip to content

Commit 479c861

Browse files
authored
Merge pull request #7726 from lcy-seso/fix_rendering_error_of_transpose_op
fix rendering error of transpose operator and add wrapper.
2 parents 6d2cfe9 + dcb5a1e commit 479c861

File tree

5 files changed

+132
-61
lines changed

5 files changed

+132
-61
lines changed

paddle/operators/transpose_op.cc

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,44 +59,39 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
5959
: OpProtoAndCheckerMaker(proto, op_checker) {
6060
AddInput(
6161
"X",
62-
"(Tensor)The input tensor, tensors with rank at most 6 are supported");
63-
AddOutput("Out", "(Tensor)The output tensor");
62+
"(Tensor) The input tensor, tensors with rank up to 6 are supported.");
63+
AddOutput("Out", "(Tensor)The output tensor.");
6464
AddAttr<std::vector<int>>(
6565
"axis",
66-
"(vector<int>)A list of values, and the size of the list should be "
67-
"the same with the input tensor rank, the tensor will "
68-
"permute the axes according the the values given");
66+
"(vector<int>) A list of values, and the size of the list should be "
67+
"the same with the input tensor rank. This operator permutes the input "
68+
"tensor's axes according to the values given.");
6969
AddComment(R"DOC(
7070
Transpose Operator.
7171
72-
The input tensor will be permuted according to the axis values given.
73-
The op functions is similar to how numpy.transpose works in python.
72+
The input tensor will be permuted according to the axes given.
73+
The behavior of this operator is similar to how `numpy.transpose` works.
7474
75-
For example:
75+
- suppose the input `X` is a 2-D tensor:
76+
$$
77+
X = \begin{pmatrix}
78+
0 &1 &2 \\
79+
3 &4 &5
80+
\end{pmatrix}$$
7681
77-
.. code-block:: text
82+
the given `axes` is: $[1, 0]$, and $Y$ = transpose($X$, axis)
7883
79-
input = numpy.arange(6).reshape((2,3))
84+
then the output $Y$ is:
8085
81-
the input is:
86+
$$
87+
Y = \begin{pmatrix}
88+
0 &3 \\
89+
1 &4 \\
90+
2 &5
91+
\end{pmatrix}$$
8292
83-
array([[0, 1, 2],
84-
[3, 4, 5]])
85-
86-
given axis is:
87-
88-
[1, 0]
89-
90-
output = input.transpose(axis)
91-
92-
then the output is:
93-
94-
array([[0, 3],
95-
[1, 4],
96-
[2, 5]])
97-
98-
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
99-
the output tensor shape will be (N, H, W, C)
93+
- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
94+
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.
10095
10196
)DOC");
10297
}

python/paddle/v2/dataset/wmt16.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
171171
callable: The train reader.
172172
"""
173173

174-
assert (src_lang in ["en", "de"], ("An error language type. Only support: "
175-
"en (for English); de(for Germany)"))
174+
if src_lang not in ["en", "de"]:
175+
raise ValueError("An error language type. Only support: "
176+
"en (for English); de(for Germany).")
176177
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
177178
src_lang)
178179

@@ -218,9 +219,9 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
218219
callable: The test reader.
219220
"""
220221

221-
assert (src_lang in ["en", "de"],
222-
("An error language type. "
223-
"Only support: en (for English); de(for Germany)"))
222+
if src_lang not in ["en", "de"]:
223+
raise ValueError("An error language type. "
224+
"Only support: en (for English); de(for Germany).")
224225

225226
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
226227
src_lang)
@@ -266,9 +267,9 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
266267
Returns:
267268
callable: The validation reader.
268269
"""
269-
assert (src_lang in ["en", "de"],
270-
("An error language type. "
271-
"Only support: en (for English); de(for Germany)"))
270+
if src_lang not in ["en", "de"]:
271+
raise ValueError("An error language type. "
272+
"Only support: en (for English); de(for Germany).")
272273
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
273274
src_lang)
274275

python/paddle/v2/fluid/layers/nn.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,41 @@
2222
from tensor import concat
2323

2424
__all__ = [
25-
'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
26-
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
27-
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
28-
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
29-
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
30-
'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
31-
'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc',
32-
'sequence_reshape'
25+
'fc',
26+
'embedding',
27+
'dynamic_lstm',
28+
'gru_unit',
29+
'linear_chain_crf',
30+
'crf_decoding',
31+
'cos_sim',
32+
'cross_entropy',
33+
'square_error_cost',
34+
'accuracy',
35+
'chunk_eval',
36+
'sequence_conv',
37+
'conv2d',
38+
'sequence_pool',
39+
'pool2d',
40+
'batch_norm',
41+
'beam_search_decode',
42+
'conv2d_transpose',
43+
'sequence_expand',
44+
'lstm_unit',
45+
'reduce_sum',
46+
'reduce_mean',
47+
'reduce_max',
48+
'reduce_min',
49+
'sequence_first_step',
50+
'sequence_last_step',
51+
'dropout',
52+
'split',
53+
'ctc_greedy_decoder',
54+
'edit_distance',
55+
'l2_normalize',
56+
'matmul',
57+
'warpctc',
58+
'sequence_reshape',
59+
'transpose',
3360
]
3461

3562

@@ -44,14 +71,14 @@ def fc(input,
4471
**Fully Connected Layer**
4572
4673
The fully connected layer can take multiple tensors as its inputs. It
47-
creates a variable (one for each input tensor) called weights for each input
48-
tensor, which represents a fully connected weight matrix from each input
49-
unit to each output unit. The fully connected layer multiplies each input
50-
tensor with its coresponding weight to produce an output Tensor. If
51-
multiple input tensors are given, the results of multiple multiplications
52-
will be sumed up. If bias_attr is not None, a biases variable will be
53-
created and added to the output. Finally, if activation is not None,
54-
it will be applied to the output as well.
74+
creates a variable (one for each input tensor) called weights for each
75+
input tensor, which represents a fully connected weight matrix from
76+
each input unit to each output unit. The fully connected layer
77+
multiplies each input tensor with its coresponding weight to produce
78+
an output Tensor. If multiple input tensors are given, the results of
79+
multiple multiplications will be sumed up. If bias_attr is not None,
80+
a biases variable will be created and added to the output. Finally,
81+
if activation is not None, it will be applied to the output as well.
5582
5683
This process can be formulated as follows:
5784
@@ -1814,11 +1841,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
18141841
18151842
- If both are 2-D, they are multiplied like conventional matrices.
18161843
- If either is n-D, it is treated as a stack of matrices residing in the
1817-
last two dimensions and a batched matrix multiply supporting broadcast
1844+
last two dimensions and a batched matrix multiply supporting broadcast
18181845
applies on the two tensors.
18191846
1820-
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
1821-
nontransposed, the prepended or appended dimension :math:`1` will be
1847+
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
1848+
nontransposed, the prepended or appended dimension :math:`1` will be
18221849
removed after matrix multiplication.
18231850
18241851
Args:
@@ -2112,3 +2139,41 @@ def sequence_reshape(input, new_dim):
21122139
outputs={'Out': [out]},
21132140
attrs={'new_dim': new_dim})
21142141
return out
2142+
2143+
2144+
def transpose(x, perm, name=None):
2145+
"""
2146+
**transpose Layer**
2147+
2148+
Permute the dimensions of `input` according to `perm`.
2149+
2150+
The `i`-th dimension of the returned tensor will correspond to the
2151+
perm[i]-th dimension of `input`.
2152+
2153+
Args:
2154+
input (Variable): (Tensor), A Tensor.
2155+
perm (list): A permutation of the dimensions of `input`.
2156+
2157+
Returns:
2158+
Variable: A transposed Tensor.
2159+
2160+
Examples:
2161+
.. code-block:: python
2162+
2163+
x = fluid.layers.data(name='x', shape=[5, 10, 15], dtype='float32')
2164+
x_transposed = layers.transpose(x, perm=[1, 0, 2])
2165+
"""
2166+
2167+
if len(perm) != len(x.shape):
2168+
raise ValueError(
2169+
"Input(perm) is the permutation of dimensions of Input(input). "
2170+
"It's length shoud be equal to Input(input)'s rank.")
2171+
2172+
helper = LayerHelper('transpose', **locals())
2173+
out = helper.create_tmp_variable(x.dtype)
2174+
helper.append_op(
2175+
type='transpose',
2176+
inputs={'X': [x]},
2177+
outputs={'Out': [out]},
2178+
attrs={'axis': perm})
2179+
return out

python/paddle/v2/fluid/layers/ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,20 @@
4545
]
4646

4747
__all__ = [
48-
'mean', 'mul', 'reshape', 'scale', 'transpose',
49-
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
50-
'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min',
51-
'clip', 'clip_by_norm', 'sequence_softmax'
48+
'mean',
49+
'mul',
50+
'reshape',
51+
'scale',
52+
'sigmoid_cross_entropy_with_logits',
53+
'elementwise_add',
54+
'elementwise_div',
55+
'elementwise_sub',
56+
'elementwise_mul',
57+
'elementwise_max',
58+
'elementwise_min',
59+
'clip',
60+
'clip_by_norm',
61+
'sequence_softmax',
5262
] + __activations__
5363

5464
for _OP in set(__all__):

python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
6565

6666
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
6767
emb = fluid.layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim])
68-
emb = fluid.layers.transpose(x=emb, axis=[1, 0, 2])
68+
emb = fluid.layers.transpose(x=emb, perm=[1, 0, 2])
6969

7070
c_pre_init = fluid.layers.fill_constant(
7171
dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0)
7272
c_pre_init.stop_gradient = False
7373
layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
74-
layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2])
74+
layer_1_out = fluid.layers.transpose(x=layer_1_out, perm=[1, 0, 2])
7575

7676
prediction = fluid.layers.fc(input=layer_1_out,
7777
size=class_dim,

0 commit comments

Comments
 (0)