Skip to content

Commit 9bc2e5a

Browse files
author
Raghuveer Devulapalli
committed
Fix bug in key-value networks
1 parent 9b398b8 commit 9bc2e5a

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

src/avx512-64bit-keyvalue-networks.hpp

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1,
9696
zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2);
9797
zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2);
9898

99-
index_type index_zmm3 = vtype2::mask_mov(
100-
index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1);
101-
index_type index_zmm4 = vtype2::mask_mov(
102-
index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2);
99+
typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1);
100+
101+
index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1);
102+
index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2);
103+
104+
/* need to reverse the lower registers to keep the correct order */
105+
key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4);
106+
index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4);
103107

104108
// 2) Recursive half cleaner for each
105109
key_zmm1 = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm3, index_zmm3);
@@ -129,14 +133,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
129133
zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r);
130134
zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r);
131135

136+
typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
137+
typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
138+
132139
index_type index_zmm_t1 = vtype2::mask_mov(
133-
index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
140+
index_zmm3r, movmask1, index_zmm[0]);
134141
index_type index_zmm_m1 = vtype2::mask_mov(
135-
index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r);
142+
index_zmm[0], movmask1, index_zmm3r);
136143
index_type index_zmm_t2 = vtype2::mask_mov(
137-
index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
144+
index_zmm2r, movmask2, index_zmm[1]);
138145
index_type index_zmm_m2 = vtype2::mask_mov(
139-
index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r);
146+
index_zmm[1], movmask2, index_zmm2r);
140147

141148
// 2) Recursive half cleaner: 16
142149
zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2);
@@ -149,14 +156,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
149156
zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4);
150157
zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4);
151158

159+
movmask1 = vtype1::eq(key_zmm0, key_zmm_t1);
160+
movmask2 = vtype1::eq(key_zmm2, key_zmm_t3);
161+
152162
index_type index_zmm0 = vtype2::mask_mov(
153-
index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1);
163+
index_zmm_t2, movmask1, index_zmm_t1);
154164
index_type index_zmm1 = vtype2::mask_mov(
155-
index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2);
165+
index_zmm_t1, movmask1, index_zmm_t2);
156166
index_type index_zmm2 = vtype2::mask_mov(
157-
index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3);
167+
index_zmm_t4, movmask2, index_zmm_t3);
158168
index_type index_zmm3 = vtype2::mask_mov(
159-
index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4);
169+
index_zmm_t3, movmask2, index_zmm_t4);
160170

161171
key_zmm[0] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm0, index_zmm0);
162172
key_zmm[1] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm1, index_zmm1);
@@ -197,22 +207,27 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,
197207
zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r);
198208
zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r);
199209

210+
typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
211+
typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
212+
typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]);
213+
typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]);
214+
200215
index_type index_zmm_t1 = vtype2::mask_mov(
201-
index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
216+
index_zmm7r, movmask1, index_zmm[0]);
202217
index_type index_zmm_m1 = vtype2::mask_mov(
203-
index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r);
218+
index_zmm[0], movmask1, index_zmm7r);
204219
index_type index_zmm_t2 = vtype2::mask_mov(
205-
index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
220+
index_zmm6r, movmask2, index_zmm[1]);
206221
index_type index_zmm_m2 = vtype2::mask_mov(
207-
index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r);
222+
index_zmm[1], movmask2, index_zmm6r);
208223
index_type index_zmm_t3 = vtype2::mask_mov(
209-
index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]);
224+
index_zmm5r, movmask3, index_zmm[2]);
210225
index_type index_zmm_m3 = vtype2::mask_mov(
211-
index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r);
226+
index_zmm[2], movmask3, index_zmm5r);
212227
index_type index_zmm_t4 = vtype2::mask_mov(
213-
index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]);
228+
index_zmm4r, movmask4, index_zmm[3]);
214229
index_type index_zmm_m4 = vtype2::mask_mov(
215-
index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r);
230+
index_zmm[3], movmask4, index_zmm4r);
216231

217232
zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4);
218233
zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3);

0 commit comments

Comments
 (0)