
Commit 0c8c0d9

fix macunittest (#13434)
1 parent 1e44201 commit 0c8c0d9

File tree

4 files changed: +73 -80 lines changed

paddle/fluid/operators/math/cpu_lstm_compute.cc
paddle/fluid/operators/math/cpu_lstm_compute.h
python/paddle/fluid/tests/unittests/test_desc_clone.py
python/paddle/fluid/transpiler/details/program_utils.py
paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 1 addition & 71 deletions

@@ -1,88 +1,18 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
     http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/cpu_lstm_compute.h"
-#include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/platform/cpu_info.h"
-#ifdef __AVX__
-#include <immintrin.h>
-#endif
 
 namespace paddle {
 namespace operators {
-namespace math {
-
-// TODO(TJ): ugly workaround, clean me
-template <typename T>
-void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
-  // gates: W_ch, W_ih, W_fh, W_oh
-  vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
-  vec_tanh<T, platform::jit::avx>(8, gates, gates);
-  const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
-  const T min = SIGMOID_THRESHOLD_MIN;
-  const T max = SIGMOID_THRESHOLD_MAX;
-  for (int d = 0; d < 8; ++d) {
-    // C_t = C_t-1 * fgated + cand_gated * igated
-    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
-    // H_t = act_cell(C_t) * ogated
-    T tmp = ct[d] * 2;
-    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
-    vec_exp<T>(1, &tmp, &tmp);
-    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
-    ht[d] = tmp * o[d];
-  }
-}
-
-#ifdef __AVX__
-namespace detail {
-namespace forward {
-namespace avx {
-__m256 Sigmoid(const __m256 a);
-__m256 Tanh(const __m256 a);
-}  // namespace avx
-}  // namespace forward
-}  // namespace detail
-
-template <>
-void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
-                              float* ht) {
-  namespace act = detail::forward::avx;
-  // gates: W_ch, W_ih, W_fh, W_oh
-  __m256 c, i, f, o;
-  c = _mm256_loadu_ps(gates);
-  i = _mm256_loadu_ps(gates + 8);
-  f = _mm256_loadu_ps(gates + 16);
-  o = _mm256_loadu_ps(gates + 24);
-
-  /* C_t = C_t-1 * fgated + cand_gated * igated*/
-  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
-  i = _mm256_loadu_ps(ct_1);
-  f = _mm256_mul_ps(i, act::Sigmoid(f));
-  f = _mm256_add_ps(c, f);
-  _mm256_storeu_ps(ct, f);
-
-  /* H_t = act_cell(C_t) * ogated */
-  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
-  _mm256_storeu_ps(ht, o);
-}
-#endif
-
-template void lstm_compute_ctht<float>(float* gates, const float* ct_1,
-                                       float* ct, float* ht);
-template void lstm_compute_ctht<double>(double* gates, const double* ct_1,
-                                        double* ct, double* ht);
-
-}  // namespace math
+namespace math {}  // namespace math
 }  // namespace operators
 }  // namespace paddle
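
In short: the lstm_compute_ctht template implementation, its AVX float specialization, and the explicit float/double instantiations all move out of this translation unit into cpu_lstm_compute.h (next file), leaving only an empty math namespace behind.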

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 57 additions & 4 deletions
@@ -1,11 +1,8 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
     http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -14,14 +11,70 @@ limitations under the License. */
 
 #pragma once
 #include <string>
+#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/platform/cpu_info.h"
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
 
 namespace paddle {
 namespace operators {
 namespace math {
 
 // TODO(TJ): ugly workaround, clean me
 template <typename T>
-void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht);
+void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
+  // gates: W_ch, W_ih, W_fh, W_oh
+  vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
+  vec_tanh<T, platform::jit::avx>(8, gates, gates);
+  const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  for (int d = 0; d < 8; ++d) {
+    // C_t = C_t-1 * fgated + cand_gated * igated
+    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
+    // H_t = act_cell(C_t) * ogated
+    T tmp = ct[d] * 2;
+    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
+    vec_exp<T>(1, &tmp, &tmp);
+    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
+    ht[d] = tmp * o[d];
+  }
+}
+
+#ifdef __AVX__
+namespace detail {
+namespace forward {
+namespace avx {
+__m256 Sigmoid(const __m256 a);
+__m256 Tanh(const __m256 a);
+}  // namespace avx
+}  // namespace forward
+}  // namespace detail
+
+template <>
+void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
+                              float* ht) {
+  namespace act = detail::forward::avx;
+  // gates: W_ch, W_ih, W_fh, W_oh
+  __m256 c, i, f, o;
+  c = _mm256_loadu_ps(gates);
+  i = _mm256_loadu_ps(gates + 8);
+  f = _mm256_loadu_ps(gates + 16);
+  o = _mm256_loadu_ps(gates + 24);
+
+  /* C_t = C_t-1 * fgated + cand_gated * igated*/
+  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
+  i = _mm256_loadu_ps(ct_1);
+  f = _mm256_mul_ps(i, act::Sigmoid(f));
+  f = _mm256_add_ps(c, f);
+  _mm256_storeu_ps(ct, f);
+
+  /* H_t = act_cell(C_t) * ogated */
+  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
+  _mm256_storeu_ps(ht, o);
+}
+#endif
 
 }  // namespace math
 }  // namespace operators
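
For reference, the relocated template computes a fixed-width (8-element) LSTM cell update. A sketch of the math the scalar loop implements, in the code's own terms, where i, f, o are the gate slices already sigmoid-activated by vec_sigmoid and the candidate \tilde{c} is gates[0..7] after vec_tanh:

\[
c_t = f \odot c_{t-1} + \tilde{c} \odot i, \qquad
h_t = \tanh(c_t) \odot o, \qquad
\tanh(x) = \frac{2}{1 + e^{-2x}} - 1
\]

The tmp sequence in the loop is exactly the rightmost identity: it forms 2*c_t, clips it to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], exponentiates with vec_exp, then finishes with 2/(1+e^{-2x}) - 1. The AVX specialization computes the same update eight floats at a time. Defining everything in the header makes the symbols visible to every including translation unit, which, going by the commit title, is what the failing macOS unit test needed.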

python/paddle/fluid/tests/unittests/test_desc_clone.py

Lines changed: 7 additions & 2 deletions
@@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
     return t
 
 
+from paddle.fluid.transpiler.details import op_to_code
+
+
 def operator_equal(a, b):
+    if op_to_code(a) != op_to_code(b):
+        raise ValueError("In operator_equal not equal\n")
+
     for k, v in six.iteritems(a.__dict__):
         if isinstance(v, fluid.framework.Program) or \
                 isinstance(v, fluid.framework.Block):
             continue
 
         elif isinstance(v, core.OpDesc):
-            if v.serialize_to_string() != b.__dict__[k].serialize_to_string():
-                raise ValueError("In operator_equal not equal:{0}\n".format(k))
+            continue
 
         elif isinstance(v, collections.OrderedDict):
             v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0])
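
A minimal usage sketch of the rewritten check, assuming the 2018-era fluid API; the tiny network below is illustrative and not part of the diff:

import paddle.fluid as fluid
from paddle.fluid.transpiler.details import op_to_code

main = fluid.Program()
with fluid.program_guard(main):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)

# operator_equal now fails fast when the canonical string renderings differ,
# and skips the OpDesc branch entirely, presumably because comparing
# serialized protobufs was order-sensitive across platforms.
clone = main.clone()
for op_a, op_b in zip(main.global_block().ops, clone.global_block().ops):
    assert op_to_code(op_a) == op_to_code(op_b)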

python/paddle/fluid/transpiler/details/program_utils.py

Lines changed: 8 additions & 3 deletions
@@ -113,27 +113,32 @@ def op_to_code(op):
             inputs_str += ", "
     inputs_str += "}"
 
+    attr_names = sorted(op.attr_names)
     attrs_str = ""
-    for i in range(0, len(op.attr_names)):
-        name = op.attr_names[i]
+    for i in range(0, len(attr_names)):
+        name = attr_names[i]
 
         attr_type = op.desc.attr_type(name)
         if attr_type == core.AttrType.BLOCK:
             a = "{name} = block[{value}]".format(
                 name=name, type=attr_type, value=op.block_attr_id(name))
             attrs_str += a
+            if i != len(attr_names) - 1:
+                attrs_str += ", "
             continue
 
         if attr_type == core.AttrType.BLOCKS:
             a = "{name} = blocks{value}".format(
                 name=name, type=attr_type, value=op.blocks_attr_ids(name))
             attrs_str += a
+            if i != len(attr_names) - 1:
+                attrs_str += ", "
             continue
 
         a = "{name} = {value}".format(
             name=name, type=attr_type, value=op.desc.attr(name))
         attrs_str += a
-        if i != len(op.attr_names) - 1:
+        if i != len(attr_names) - 1:
             attrs_str += ", "
 
     if outputs_str != "{}":
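
Why the sort matters, as an illustrative sketch (the attribute names below are hypothetical): op.attr_names need not enumerate in the same order on every platform, so an unsorted rendering could differ between two otherwise identical operators, which is fatal once test_desc_clone.py compares ops by their op_to_code strings. Sorting first makes the rendering canonical:

# Hypothetical enumerations of the same attribute set on two platforms.
names_a = ["use_mkldnn", "axis"]
names_b = ["axis", "use_mkldnn"]
assert names_a != names_b                  # raw order may differ...
assert sorted(names_a) == sorted(names_b)  # ...sorted order is canonical

The two new separator guards inside the BLOCK and BLOCKS branches fix a related rendering bug: the early continue used to skip the trailing ", ", gluing the next attribute onto a block attribute.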
