Skip to content

Commit f83254d

Browse files
authored
cherry-pick pyramid_hash op test=develop (#20779)(#18525) (#21562)
1 parent e228e70 commit f83254d

File tree

8 files changed

+819
-39
lines changed

8 files changed

+819
-39
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@ if (WITH_DISTRIBUTE)
4848
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
4949
endif()
5050

51-
SET(OP_ONLY_MKL "")
52-
if (NOT WITH_MKL)
53-
SET(OP_ONLY_MKL ${OP_ONLY_MKL} match_matrix_tensor_op)
54-
SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op)
51+
SET(OP_MKL_DEPS "")
52+
if (NOT WITH_MKL OR NOT WITH_AVX)
53+
SET(OP_MKL_DEPS ${OP_MKL_DEPS} match_matrix_tensor_op)
54+
SET(OP_MKL_DEPS ${OP_MKL_DEPS} var_conv_2d_op)
55+
endif()
56+
if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32)
57+
SET(OP_MKL_DEPS ${OP_MKL_DEPS} pyramid_hash_op)
5558
endif()
5659

5760
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
58-
sync_batch_norm_op multihead_matmul_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
61+
sync_batch_norm_op multihead_matmul_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
5962

6063
if (WITH_GPU)
6164
# warpctc_op needs cudnn 7 above

paddle/fluid/operators/match_matrix_tensor_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
286286
auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
287287
auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
288288
if (diff != 0.0) {
289-
sse_axpy(r_data, l_trans_diff, dim_in, diff);
290-
sse_axpy(l_trans_data, r_diff, dim_in, diff);
289+
avx_axpy(r_data, l_trans_diff, dim_in, diff);
290+
avx_axpy(l_trans_data, r_diff, dim_in, diff);
291291
}
292292
}
293293
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#define BLOOMFILTER_MAGIC_NUM_NEW 17070416
17+
18+
#include <inttypes.h>
19+
#include <stdlib.h>
20+
21+
#include <stdio.h>
22+
#include <string.h>
23+
24+
#include <unistd.h>
25+
26+
namespace paddle {
27+
namespace operators {
28+
namespace math {
29+
30+
// Header of a bloom filter blob. The struct is packed to 4-byte alignment;
// the packing is pushed/popped so that it applies to this struct only and
// does not leak into declarations that follow this header in an including
// translation unit.
#pragma pack(push, 4)
struct bloomfilter {
  // Format tag; bloomfilter_check() compares it with
  // BLOOMFILTER_MAGIC_NUM_NEW.
  uint64_t magic_num;
  // Modulus applied to each hash value to obtain a bit index.
  uint64_t m;
  // Number of hash rounds; round i hashes the key with seed i.
  uint64_t k;
  // Not read by the functions in this header.
  // NOTE(review): presumably the number of inserted keys -- confirm.
  uint64_t count;
  // First byte of the bit array. Lookups index bits in [0, m), so the real
  // storage is expected to extend past this one declared byte.
  unsigned char bit_vector[1];
};
#pragma pack(pop)

// Returns 1 when every probed bit for `key` is set, 0 otherwise.
int bloomfilter_get(const struct bloomfilter *bloomfilter, const void *key,
                    size_t len);
// Returns 1 when `filter` carries the expected magic number, 0 otherwise.
int bloomfilter_check(struct bloomfilter *filter);
41+
42+
#define bit_get(v, n) ((v)[(n) >> 3] & (0x1 << (0x7 - ((n)&0x7))))
43+
#define ROTL64(x, r) (((x) << (r)) | ((x) >> (64 - (r))))
44+
#define BIG_CONSTANT(x) (x##LLU)
45+
46+
// 64-bit finalization mix used by murmurhash3_x64_128: alternating
// xor-shift and multiply steps that spread every input bit over the result.
// Declared inline because this header defines it and may be included from
// more than one translation unit.
inline uint64_t fmix64(uint64_t k) {
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccdULL;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53ULL;
  k ^= k >> 33;
  return k;
}

// Computes the 128-bit MurmurHash3 (x64 variant) of a byte string.
//
//   key  - first byte of the data; exactly `len` bytes are read.
//   len  - number of bytes to hash; must be non-negative.
//   seed - initial value of both 64-bit hash lanes.
//   out  - receives 16 bytes: the first 64-bit lane, then the second.
//
// All loads and stores go through memcpy, so neither `key` nor `out` needs
// to be 8-byte aligned and no byte beyond `key + len` is touched. Words are
// assembled in native byte order, which on little-endian targets yields the
// same values as the previous pointer-cast implementation.
inline void murmurhash3_x64_128(const void *key, const int len,
                                const uint32_t seed, void *out) {
  const uint8_t *data = static_cast<const uint8_t *>(key);
  const int nblocks = len / 16;

  const uint64_t c1 = 0x87c37b91114253d5ULL;
  const uint64_t c2 = 0x4cf5ad432745937fULL;

  // Rotate left; only called with 0 < r < 64.
  const auto rotl64 = [](uint64_t x, int r) -> uint64_t {
    return (x << r) | (x >> (64 - r));
  };

  uint64_t h1 = seed;
  uint64_t h2 = seed;

  //----------
  // body: consume the key 16 bytes (two 64-bit words) at a time

  for (int i = 0; i < nblocks; ++i) {
    uint64_t k1;
    uint64_t k2;
    memcpy(&k1, data + i * 16, sizeof(k1));
    memcpy(&k2, data + i * 16 + 8, sizeof(k2));

    k1 *= c1;
    k1 = rotl64(k1, 31);
    k1 *= c2;
    h1 ^= k1;

    h1 = rotl64(h1, 27);
    h1 += h2;
    h1 = h1 * 5 + 0x52dce729;

    k2 *= c2;
    k2 = rotl64(k2, 33);
    k2 *= c1;
    h2 ^= k2;

    h2 = rotl64(h2, 31);
    h2 += h1;
    h2 = h2 * 5 + 0x38495ab5;
  }

  //----------
  // tail: the remaining 0..15 bytes, zero-extended into up to two words

  const uint8_t *tail = data + nblocks * 16;
  const int flag = len & 15;

  if (flag > 8) {
    // Bytes 8..flag-1 of the tail feed the second lane.
    uint64_t nk2 = 0;
    memcpy(&nk2, tail + 8, static_cast<size_t>(flag - 8));
    nk2 *= c2;
    nk2 = rotl64(nk2, 33);
    nk2 *= c1;
    h2 ^= nk2;
  }

  if (flag > 0) {
    // Bytes 0..min(flag, 8)-1 of the tail feed the first lane.
    uint64_t nk1 = 0;
    memcpy(&nk1, tail, static_cast<size_t>(flag < 8 ? flag : 8));
    nk1 *= c1;
    nk1 = rotl64(nk1, 31);
    nk1 *= c2;
    h1 ^= nk1;
  }

  //----------
  // finalization

  h1 ^= static_cast<uint64_t>(len);
  h2 ^= static_cast<uint64_t>(len);

  h1 += h2;
  h2 += h1;

  h1 = fmix64(h1);
  h2 = fmix64(h2);

  h1 += h2;
  h2 += h1;

  const uint64_t digest[2] = {h1, h2};
  memcpy(out, digest, sizeof(digest));
}
148+
149+
int bloomfilter_check(struct bloomfilter *filter) {
150+
if (filter->magic_num == BLOOMFILTER_MAGIC_NUM_NEW) {
151+
return 1;
152+
} else {
153+
fprintf(stderr, "error magic_num %ld\n", filter->magic_num);
154+
return 0;
155+
}
156+
}
157+
158+
int bloomfilter_get(const struct bloomfilter *bloomfilter, const void *key,
159+
size_t len) {
160+
uint32_t i;
161+
uint64_t result[2];
162+
163+
for (i = 0; i < bloomfilter->k; i++) {
164+
murmurhash3_x64_128(key, len, i, &result);
165+
result[0] %= bloomfilter->m;
166+
result[1] %= bloomfilter->m;
167+
if (!bit_get(bloomfilter->bit_vector, result[0])) {
168+
return 0;
169+
}
170+
if (!bit_get(bloomfilter->bit_vector, result[1])) {
171+
return 0;
172+
}
173+
}
174+
return 1;
175+
}
176+
177+
} // namespace math
178+
} // namespace operators
179+
} // namespace paddle

0 commit comments

Comments
 (0)