@@ -99,38 +99,11 @@ struct avx512_16bit_swizzle_ops {
99
99
__m512i v = vtype::cast_to (reg);
100
100
101
101
if constexpr (scale == 2 ) {
102
- __m512i mask = _mm512_set_epi16 (30 ,
103
- 31 ,
104
- 28 ,
105
- 29 ,
106
- 26 ,
107
- 27 ,
108
- 24 ,
109
- 25 ,
110
- 22 ,
111
- 23 ,
112
- 20 ,
113
- 21 ,
114
- 18 ,
115
- 19 ,
116
- 16 ,
117
- 17 ,
118
- 14 ,
119
- 15 ,
120
- 12 ,
121
- 13 ,
122
- 10 ,
123
- 11 ,
124
- 8 ,
125
- 9 ,
126
- 6 ,
127
- 7 ,
128
- 4 ,
129
- 5 ,
130
- 2 ,
131
- 3 ,
132
- 0 ,
133
- 1 );
102
+ std::vector<uint16_t > arr
103
+ = {1 , 0 , 3 , 2 , 5 , 4 , 7 , 6 , 9 , 8 , 11 ,
104
+ 10 , 13 , 12 , 15 , 14 , 17 , 16 , 19 , 18 , 21 , 20 ,
105
+ 23 , 22 , 25 , 24 , 27 , 26 , 29 , 28 , 31 , 30 };
106
+ __m512i mask = _mm512_loadu_si512 (arr.data ());
134
107
v = _mm512_permutexvar_epi16 (mask, v);
135
108
}
136
109
else if constexpr (scale == 4 ) {
@@ -160,108 +133,27 @@ struct avx512_16bit_swizzle_ops {
160
133
161
134
if constexpr (scale == 2 ) { return swap_n<vtype, 2 >(reg); }
162
135
else if constexpr (scale == 4 ) {
163
- __m512i mask = _mm512_set_epi16 (28 ,
164
- 29 ,
165
- 30 ,
166
- 31 ,
167
- 24 ,
168
- 25 ,
169
- 26 ,
170
- 27 ,
171
- 20 ,
172
- 21 ,
173
- 22 ,
174
- 23 ,
175
- 16 ,
176
- 17 ,
177
- 18 ,
178
- 19 ,
179
- 12 ,
180
- 13 ,
181
- 14 ,
182
- 15 ,
183
- 8 ,
184
- 9 ,
185
- 10 ,
186
- 11 ,
187
- 4 ,
188
- 5 ,
189
- 6 ,
190
- 7 ,
191
- 0 ,
192
- 1 ,
193
- 2 ,
194
- 3 );
136
+ std::vector<uint16_t > arr
137
+ = {3 , 2 , 1 , 0 , 7 , 6 , 5 , 4 , 11 , 10 , 9 ,
138
+ 8 , 15 , 14 , 13 , 12 , 19 , 18 , 17 , 16 , 23 , 22 ,
139
+ 21 , 20 , 27 , 26 , 25 , 24 , 31 , 30 , 29 , 28 };
140
+ __m512i mask = _mm512_loadu_si512 (arr.data ());
195
141
v = _mm512_permutexvar_epi16 (mask, v);
196
142
}
197
143
else if constexpr (scale == 8 ) {
198
- __m512i mask = _mm512_set_epi16 (24 ,
199
- 25 ,
200
- 26 ,
201
- 27 ,
202
- 28 ,
203
- 29 ,
204
- 30 ,
205
- 31 ,
206
- 16 ,
207
- 17 ,
208
- 18 ,
209
- 19 ,
210
- 20 ,
211
- 21 ,
212
- 22 ,
213
- 23 ,
214
- 8 ,
215
- 9 ,
216
- 10 ,
217
- 11 ,
218
- 12 ,
219
- 13 ,
220
- 14 ,
221
- 15 ,
222
- 0 ,
223
- 1 ,
224
- 2 ,
225
- 3 ,
226
- 4 ,
227
- 5 ,
228
- 6 ,
229
- 7 );
144
+ std::vector<uint16_t > arr
145
+ = {7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 , 15 , 14 , 13 ,
146
+ 12 , 11 , 10 , 9 , 8 , 23 , 22 , 21 , 20 , 19 , 18 ,
147
+ 17 , 16 , 31 , 30 , 29 , 28 , 27 , 26 , 25 , 24 };
148
+ __m512i mask = _mm512_loadu_si512 (arr.data ());
230
149
v = _mm512_permutexvar_epi16 (mask, v);
231
150
}
232
151
else if constexpr (scale == 16 ) {
233
- __m512i mask = _mm512_set_epi16 (16 ,
234
- 17 ,
235
- 18 ,
236
- 19 ,
237
- 20 ,
238
- 21 ,
239
- 22 ,
240
- 23 ,
241
- 24 ,
242
- 25 ,
243
- 26 ,
244
- 27 ,
245
- 28 ,
246
- 29 ,
247
- 30 ,
248
- 31 ,
249
- 0 ,
250
- 1 ,
251
- 2 ,
252
- 3 ,
253
- 4 ,
254
- 5 ,
255
- 6 ,
256
- 7 ,
257
- 8 ,
258
- 9 ,
259
- 10 ,
260
- 11 ,
261
- 12 ,
262
- 13 ,
263
- 14 ,
264
- 15 );
152
+ std::vector<uint16_t > arr
153
+ = {15 , 14 , 13 , 12 , 11 , 10 , 9 , 8 , 7 , 6 , 5 ,
154
+ 4 , 3 , 2 , 1 , 0 , 31 , 30 , 29 , 28 , 27 , 26 ,
155
+ 25 , 24 , 23 , 22 , 21 , 20 , 19 , 18 , 17 , 16 };
156
+ __m512i mask = _mm512_loadu_si512 (arr.data ());
265
157
v = _mm512_permutexvar_epi16 (mask, v);
266
158
}
267
159
else if constexpr (scale == 32 ) {
0 commit comments