@@ -96,10 +96,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(zmm_t &key_zmm1,
     zmm_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2);
     zmm_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2);
 
-    index_type index_zmm3 = vtype2::mask_mov(
-            index_zmm2, vtype1::eq(key_zmm3, key_zmm1), index_zmm1);
-    index_type index_zmm4 = vtype2::mask_mov(
-            index_zmm1, vtype1::eq(key_zmm3, key_zmm1), index_zmm2);
+    typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1);
+
+    index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1);
+    index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2);
+
+    /* need to reverse the lower registers to keep the correct order */
+    key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4);
+    index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4);
 
     // 2) Recursive half cleaner for each
     key_zmm1 = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm3, index_zmm3);
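Review note: this hunk and the ones below all apply the same refactor. The key comparison is hoisted into a named opmask_t and reused by both index blends, so each vtype1::eq runs once per register pair instead of twice. A minimal sketch of the same idea with raw AVX-512 intrinsics (illustrative function and variable names, not the library's vtype wrappers) might look like:

#include <immintrin.h>

// Compare-exchange one pair of key/index registers, computing the comparison
// mask once and reusing it for both index blends (mirrors the vtype1::eq +
// vtype2::mask_mov pattern in the diff; names here are illustrative only).
static inline void coex_key_index(__m512i &key1, __m512i &key2,
                                  __m512i &idx1, __m512i &idx2)
{
    __m512i key_min = _mm512_min_epi64(key1, key2);
    __m512i key_max = _mm512_max_epi64(key1, key2);

    // Where the minimum came from key1, the min's index comes from idx1 and
    // the max's index from idx2; the remaining lanes take the other source.
    __mmask8 movmask = _mm512_cmpeq_epi64_mask(key_min, key1);

    __m512i idx_min = _mm512_mask_mov_epi64(idx2, movmask, idx1);
    __m512i idx_max = _mm512_mask_mov_epi64(idx1, movmask, idx2);

    key1 = key_min;
    key2 = key_max;
    idx1 = idx_min;
    idx2 = idx_max;
}

With the mask hoisted, each exchange costs one compare plus two blends, which is what the new movmask variables make explicit.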
@@ -129,14 +133,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r);
     zmm_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r);
 
+    typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
+    typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
+
     index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm3r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
+            index_zmm3r, movmask1, index_zmm[0]);
     index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm3r);
+            index_zmm[0], movmask1, index_zmm3r);
     index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm2r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
+            index_zmm2r, movmask2, index_zmm[1]);
     index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm2r);
+            index_zmm[1], movmask2, index_zmm2r);
 
     // 2) Recursive half clearer: 16
     zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2);
@@ -149,14 +156,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm2 = vtype1::min(key_zmm_t3, key_zmm_t4);
     zmm_t key_zmm3 = vtype1::max(key_zmm_t3, key_zmm_t4);
 
+    movmask1 = vtype1::eq(key_zmm0, key_zmm_t1);
+    movmask2 = vtype1::eq(key_zmm2, key_zmm_t3);
+
     index_type index_zmm0 = vtype2::mask_mov(
-            index_zmm_t2, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t1);
+            index_zmm_t2, movmask1, index_zmm_t1);
     index_type index_zmm1 = vtype2::mask_mov(
-            index_zmm_t1, vtype1::eq(key_zmm0, key_zmm_t1), index_zmm_t2);
+            index_zmm_t1, movmask1, index_zmm_t2);
     index_type index_zmm2 = vtype2::mask_mov(
-            index_zmm_t4, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t3);
+            index_zmm_t4, movmask2, index_zmm_t3);
     index_type index_zmm3 = vtype2::mask_mov(
-            index_zmm_t3, vtype1::eq(key_zmm2, key_zmm_t3), index_zmm_t4);
+            index_zmm_t3, movmask2, index_zmm_t4);
 
     key_zmm[0] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm0, index_zmm0);
     key_zmm[1] = bitonic_merge_zmm_64bit<vtype1, vtype2>(key_zmm1, index_zmm1);
@@ -197,22 +207,27 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm,
     zmm_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r);
     zmm_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r);
 
+    typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]);
+    typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]);
+    typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]);
+    typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]);
+
     index_type index_zmm_t1 = vtype2::mask_mov(
-            index_zmm7r, vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm[0]);
+            index_zmm7r, movmask1, index_zmm[0]);
     index_type index_zmm_m1 = vtype2::mask_mov(
-            index_zmm[0], vtype1::eq(key_zmm_t1, key_zmm[0]), index_zmm7r);
+            index_zmm[0], movmask1, index_zmm7r);
     index_type index_zmm_t2 = vtype2::mask_mov(
-            index_zmm6r, vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm[1]);
+            index_zmm6r, movmask2, index_zmm[1]);
     index_type index_zmm_m2 = vtype2::mask_mov(
-            index_zmm[1], vtype1::eq(key_zmm_t2, key_zmm[1]), index_zmm6r);
+            index_zmm[1], movmask2, index_zmm6r);
     index_type index_zmm_t3 = vtype2::mask_mov(
-            index_zmm5r, vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm[2]);
+            index_zmm5r, movmask3, index_zmm[2]);
     index_type index_zmm_m3 = vtype2::mask_mov(
-            index_zmm[2], vtype1::eq(key_zmm_t3, key_zmm[2]), index_zmm5r);
+            index_zmm[2], movmask3, index_zmm5r);
     index_type index_zmm_t4 = vtype2::mask_mov(
-            index_zmm4r, vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm[3]);
+            index_zmm4r, movmask4, index_zmm[3]);
     index_type index_zmm_m4 = vtype2::mask_mov(
-            index_zmm[3], vtype1::eq(key_zmm_t4, key_zmm[3]), index_zmm4r);
+            index_zmm[3], movmask4, index_zmm4r);
 
     zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4);
     zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3);
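The four- and eight-register merges apply the identical hoist, just with movmask1 through movmask4. If a scalar reference helps while reviewing or testing, a hypothetical per-lane equivalent of the key/index compare-exchange could look like the following (the eight-lane std::array layout is an assumption for illustration, not the library's API):

#include <array>
#include <cstdint>
#include <utility>

// Hypothetical per-lane reference of the SIMD compare-exchange: when two keys
// tie, the "min" lane keeps the index from the first register, matching the
// eq-mask tie-breaking in the diff.
static void coex_key_index_ref(std::array<int64_t, 8> &key1,
                               std::array<int64_t, 8> &key2,
                               std::array<int64_t, 8> &idx1,
                               std::array<int64_t, 8> &idx2)
{
    for (int i = 0; i < 8; ++i) {
        if (key2[i] < key1[i]) {
            std::swap(key1[i], key2[i]);
            std::swap(idx1[i], idx2[i]);
        }
    }
}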