Skip to content

Commit 4cab68f

Browse files
author
Fikret Ardal
committed
complex kernel_transpose for AVX512.
1 parent 753e51b commit 4cab68f

File tree

3 files changed

+123
-18
lines changed

3 files changed

+123
-18
lines changed

c++/nda/simd/arch/AVX/kernel.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ namespace nda::simd {
7474
simd_f8 c0c1c2c3c4c5c6c7 = simd_block[2];
7575
simd_f8 d0d1d2d3d4d5d6d7 = simd_block[3];
7676
simd_f8 e0e1e2e3e4e5e6e7 = simd_block[4];
77-
simd_f8 f0f1f2f3f4f5f6f = simd_block[5];
77+
simd_f8 f0f1f2f3f4f5f6f7 = simd_block[5];
7878
simd_f8 g0g1g2g3g4g5g6g7 = simd_block[6];
7979
simd_f8 h0h1h2h3h4h5h6h7 = simd_block[7];
8080

8181
__m256 a0b0a1b1a4b4a5b5 = _mm256_unpacklo_ps(a0a1a2a3a4a5a6a7, b0b1b2b3b4b5b6b7);
8282
__m256 a2b2a3b3a6b6a7b7 = _mm256_unpackhi_ps(a0a1a2a3a4a5a6a7, b0b1b2b3b4b5b6b7);
8383
__m256 c0d0c1d1c4d4c5d5 = _mm256_unpacklo_ps(c0c1c2c3c4c5c6c7, d0d1d2d3d4d5d6d7);
8484
__m256 c2d2c3d3c6d6c7d7 = _mm256_unpackhi_ps(c0c1c2c3c4c5c6c7, d0d1d2d3d4d5d6d7);
85-
__m256 e0f0e1f1e4f4e5f5 = _mm256_unpacklo_ps(e0e1e2e3e4e5e6e7, f0f1f2f3f4f5f6f);
86-
__m256 e2f2e3f3e6f6e7f7 = _mm256_unpackhi_ps(e0e1e2e3e4e5e6e7, f0f1f2f3f4f5f6f);
85+
__m256 e0f0e1f1e4f4e5f5 = _mm256_unpacklo_ps(e0e1e2e3e4e5e6e7, f0f1f2f3f4f5f6f7);
86+
__m256 e2f2e3f3e6f6e7f7 = _mm256_unpackhi_ps(e0e1e2e3e4e5e6e7, f0f1f2f3f4f5f6f7);
8787
__m256 g0h0g1h1g4h4g5h5 = _mm256_unpacklo_ps(g0g1g2g3g4g5g6g7, h0h1h2h3h4h5h6h7);
8888
__m256 g2h2g3h3g6h6g7h7 = _mm256_unpackhi_ps(g0g1g2g3g4g5g6g7, h0h1h2h3h4h5h6h7);
8989

@@ -135,9 +135,9 @@ namespace nda::simd {
135135
simd_cf4 c0c1c2c3 = simd_block[2];
136136
simd_cf4 d0d1d2d3 = simd_block[3];
137137

138-
__m256d a0b0a2b2 = _mm256_unpacklo_pd(_mm256_castps_pd(a0a1a2a3), _mm256_castps_pd(b0b1b2b3));
138+
__m256d a0b0a2b2 = _mm256_unpacklo_pd(_mm256_castps_pd(a0a1a2a3), _mm256_castps_pd(b0b1b2b3));
139139
__m256d a1b1a3b3 = _mm256_unpackhi_pd(_mm256_castps_pd(a0a1a2a3), _mm256_castps_pd(b0b1b2b3));
140-
__m256d c0d0c2d2 = _mm256_unpacklo_pd(_mm256_castps_pd(c0c1c2c3), _mm256_castps_pd(d0d1d2d3));
140+
__m256d c0d0c2d2 = _mm256_unpacklo_pd(_mm256_castps_pd(c0c1c2c3), _mm256_castps_pd(d0d1d2d3));
141141
__m256d c1d1c3d3 = _mm256_unpackhi_pd(_mm256_castps_pd(c0c1c2c3), _mm256_castps_pd(d0d1d2d3));
142142

143143
simd_cf4 a0b0c0d0(_mm256_castpd_ps(_mm256_permute2f128_pd(a0b0a2b2, c0d0c2d2, 0x20)));
Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,116 @@
1-
#ifdef __AVX512F__
21
#pragma once
2+
#ifdef __AVX512F__
33
#include "./type.hpp"
44
#include "./functions.hpp"
55
#include "../kernel_forward.hpp"
6+
namespace nda::simd {
7+
template <>
8+
inline std::array<simd_i16, 16> kernel_transpose(const std::array<simd_i16, 16> &simd_block) {
9+
return simd_block;
10+
}
11+
12+
13+
template <>
14+
inline std::array<simd_l8, 8> kernel_transpose(const std::array<simd_l8, 8> &simd_block) {
15+
return simd_block;
16+
}
17+
18+
template <>
19+
inline std::array<simd_f16, 16> kernel_transpose(const std::array<simd_f16, 16> &simd_block) {
20+
return simd_block;
21+
}
22+
23+
template <>
24+
inline std::array<simd_d8, 8> kernel_transpose(const std::array<simd_d8, 8> &simd_block) {
25+
return simd_block;
26+
}
27+
28+
template <>
29+
inline std::array<simd_cf8, 8> kernel_transpose(const std::array<simd_cf8, 8> &simd_block) {
30+
simd_cf8 a0a1a2a3a4a5a6a7 = simd_block[0];
31+
simd_cf8 b0b1b2b3b4b5b6b7 = simd_block[1];
32+
simd_cf8 c0c1c2c3c4c5c6c7 = simd_block[2];
33+
simd_cf8 d0d1d2d3d4d5d6d7 = simd_block[3];
34+
simd_cf8 e0e1e2e3e4e5e6e7 = simd_block[4];
35+
simd_cf8 f0f1f2f3f4f5f6f7 = simd_block[5];
36+
simd_cf8 g0g1g2g3g4g5g6g7 = simd_block[6];
37+
simd_cf8 h0h1h2h3h4h5h6h7 = simd_block[7];
38+
39+
__m512d a0b0a2b2a4b4a6b6 = _mm512_unpacklo_pd(_mm512_castps_pd(a0a1a2a3a4a5a6a7), _mm512_castps_pd(b0b1b2b3b4b5b6b7));
40+
__m512d a1b1a3b3a5b5a7b7 = _mm512_unpackhi_pd(_mm512_castps_pd(a0a1a2a3a4a5a6a7), _mm512_castps_pd(b0b1b2b3b4b5b6b7));
41+
__m512d c0d0c2d2c4d4c6d6 = _mm512_unpacklo_pd(_mm512_castps_pd(c0c1c2c3c4c5c6c7), _mm512_castps_pd(d0d1d2d3d4d5d6d7));
42+
__m512d c1d1c3d3c5d5c7d7 = _mm512_unpackhi_pd(_mm512_castps_pd(c0c1c2c3c4c5c6c7), _mm512_castps_pd(d0d1d2d3d4d5d6d7));
43+
__m512d e0f0e2f2e4f4e6f6 = _mm512_unpacklo_pd(_mm512_castps_pd(e0e1e2e3e4e5e6e7), _mm512_castps_pd(f0f1f2f3f4f5f6f7));
44+
__m512d e1f1e3f3e5f5e5f7 = _mm512_unpackhi_pd(_mm512_castps_pd(e0e1e2e3e4e5e6e7), _mm512_castps_pd(f0f1f2f3f4f5f6f7));
45+
__m512d g0h0g2h2g4h4g6h6 = _mm512_unpacklo_pd(_mm512_castps_pd(g0g1g2g3g4g5g6g7), _mm512_castps_pd(h0h1h2h3h4h5h6h7));
46+
__m512d g1h1g3h3g5h5g7h7 = _mm512_unpackhi_pd(_mm512_castps_pd(g0g1g2g3g4g5g6g7), _mm512_castps_pd(h0h1h2h3h4h5h6h7));
47+
48+
__m512d a2b2a0b0a6b6a4b4 = _mm512_shuffle_f64x2(a0b0a2b2a4b4a6b6, a0b0a2b2a4b4a6b6, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
49+
__m512d a3b3a1b1a7b7a5b5 = _mm512_shuffle_f64x2(a1b1a3b3a5b5a7b7, a1b1a3b3a5b5a7b7, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
50+
__m512d c2d2c0d0c6d6c4d4 = _mm512_shuffle_f64x2(c0d0c2d2c4d4c6d6, c0d0c2d2c4d4c6d6, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
51+
__m512d c3d3c1d1c7d7c5d5 = _mm512_shuffle_f64x2(c1d1c3d3c5d5c7d7, c1d1c3d3c5d5c7d7, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
52+
__m512d e2f2e0f0e6f6e4f4 = _mm512_shuffle_f64x2(e0f0e2f2e4f4e6f6, e0f0e2f2e4f4e6f6, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
53+
__m512d e3f3e1f1e7f7e5f5 = _mm512_shuffle_f64x2(e1f1e3f3e5f5e5f7, e1f1e3f3e5f5e5f7, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
54+
__m512d g2h2g0h0g6h6g4h4 = _mm512_shuffle_f64x2(g0h0g2h2g4h4g6h6, g0h0g2h2g4h4g6h6, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
55+
__m512d g3h3g1h1g7h7g5h5 = _mm512_shuffle_f64x2(g1h1g3h3g5h5g7h7, g1h1g3h3g5h5g7h7, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
56+
57+
__m512d a2b2c2d2a6b6c6d6 = _mm512_mask_blend_pd(0b11001100, a2b2a0b0a6b6a4b4, c0d0c2d2c4d4c6d6);
58+
__m512d a3b3c3d3a7b7c7d7 = _mm512_mask_blend_pd(0b11001100, a3b3a1b1a7b7a5b5, c1d1c3d3c5d5c7d7);
59+
__m512d a0b0c0d0a4b4c4d4 = _mm512_mask_blend_pd(0b11001100, a0b0a2b2a4b4a6b6, c2d2c0d0c6d6c4d4);
60+
__m512d a1b1c1d1a5b5c5d5 = _mm512_mask_blend_pd(0b11001100, a1b1a3b3a5b5a7b7, c3d3c1d1c7d7c5d5);
61+
__m512d e2f2g2h2e6f6g6h6 = _mm512_mask_blend_pd(0b11001100, e2f2e0f0e6f6e4f4, g0h0g2h2g4h4g6h6);
62+
__m512d e3f3g3h3e7f7g7h7 = _mm512_mask_blend_pd(0b11001100, e3f3e1f1e7f7e5f5, g1h1g3h3g5h5g7h7);
63+
__m512d e0f0g0h0e4f4g4h4 = _mm512_mask_blend_pd(0b11001100, e0f0e2f2e4f4e6f6, g2h2g0h0g6h6g4h4);
64+
__m512d e1f1g1h1e5f5g5h5 = _mm512_mask_blend_pd(0b11001100, e1f1e3f3e5f5e5f7, g3h3g1h1g7h7g5h5);
65+
66+
__m512d e0f0g0h0e0f0g0h0 = _mm512_shuffle_f64x2(e0f0g0h0e4f4g4h4, e0f0g0h0e4f4g4h4, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
67+
__m512d e1f1g1h1e1f1g1h1 = _mm512_shuffle_f64x2(e1f1g1h1e5f5g5h5, e1f1g1h1e5f5g5h5, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
68+
__m512d e2f2g2h2e2f2g2h2 = _mm512_shuffle_f64x2(e2f2g2h2e6f6g6h6, e2f2g2h2e6f6g6h6, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
69+
__m512d e3f3g3h3e3f3g3h3 = _mm512_shuffle_f64x2(e3f3g3h3e7f7g7h7, e3f3g3h3e7f7g7h7, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
70+
__m512d a4b4c4d4a4b4c4d4 = _mm512_shuffle_f64x2(a0b0c0d0a4b4c4d4, a0b0c0d0a4b4c4d4, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
71+
__m512d a5b5c5d5a5b5c5d5 = _mm512_shuffle_f64x2(a1b1c1d1a5b5c5d5, a1b1c1d1a5b5c5d5, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
72+
__m512d a6b6c6d6a6b6c6d6 = _mm512_shuffle_f64x2(a2b2c2d2a6b6c6d6, a2b2c2d2a6b6c6d6, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
73+
__m512d a7b7c7d7a7b7c7d7 = _mm512_shuffle_f64x2(a3b3c3d3a7b7c7d7, a3b3c3d3a7b7c7d7, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
74+
75+
simd_cf8 a0b0c0d0e0f0g0h0(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a0b0c0d0a4b4c4d4, e0f0g0h0e0f0g0h0)));
76+
77+
simd_cf8 a1b1c1d1e1f1g1h1(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a1b1c1d1a5b5c5d5, e1f1g1h1e1f1g1h1)));
78+
79+
simd_cf8 a2b2c2d2e2f2g2h2(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a2b2c2d2a6b6c6d6, e2f2g2h2e2f2g2h2)));
80+
81+
simd_cf8 a3b3c3d3e3f3g3h3(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a3b3c3d3a7b7c7d7, e3f3g3h3e3f3g3h3)));
82+
83+
simd_cf8 a4b4c4d4e4f4g4h4(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a4b4c4d4a4b4c4d4, e0f0g0h0e4f4g4h4)));
84+
85+
simd_cf8 a5b5c5d5e5f5g5h5(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a5b5c5d5a5b5c5d5, e1f1g1h1e5f5g5h5)));
86+
87+
simd_cf8 a6b6c6d6e6f6g6h6(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a6b6c6d6a6b6c6d6, e2f2g2h2e6f6g6h6)));
88+
89+
simd_cf8 a7b7c7d7e7f7g7h7(_mm512_castpd_ps(_mm512_mask_blend_pd(0b11110000, a7b7c7d7a7b7c7d7, e3f3g3h3e7f7g7h7)));
90+
91+
return {a0b0c0d0e0f0g0h0, a1b1c1d1e1f1g1h1, a2b2c2d2e2f2g2h2, a3b3c3d3e3f3g3h3,
92+
a4b4c4d4e4f4g4h4, a5b5c5d5e5f5g5h5, a6b6c6d6e6f6g6h6, a7b7c7d7e7f7g7h7};
93+
}
94+
95+
template <>
96+
inline std::array<simd_cd4, 4> kernel_transpose(const std::array<simd_cd4, 4> &simd_block) {
97+
simd_cd4 a0a1a2a3 = simd_block[0];
98+
simd_cd4 b0b1b2b3 = simd_block[1];
99+
simd_cd4 c0c1c2c3 = simd_block[2];
100+
simd_cd4 d0d1d2d3 = simd_block[3];
101+
102+
__m512d a0a1b0b1 = _mm512_shuffle_f64x2(a0a1a2a3, b0b1b2b3, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
103+
__m512d c0c1d0d1 = _mm512_shuffle_f64x2(c0c1c2c3, d0d1d2d3, NDA_SHUFFLE_MASK4(0, 1, 0, 1));
104+
__m512d a2a3b2b3 = _mm512_shuffle_f64x2(a0a1a2a3, b0b1b2b3, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
105+
__m512d c2c3d2d3 = _mm512_shuffle_f64x2(c0c1c2c3, d0d1d2d3, NDA_SHUFFLE_MASK4(2, 3, 2, 3));
106+
107+
simd_cd4 a0b0c0d0(_mm512_shuffle_f64x2(a0a1b0b1, c0c1d0d1, NDA_SHUFFLE_MASK4(0, 2, 0, 2)));
108+
simd_cd4 a1b1c1d1(_mm512_shuffle_f64x2(a0a1b0b1, c0c1d0d1, NDA_SHUFFLE_MASK4(1, 3, 1, 3)));
109+
simd_cd4 a2b2c2d2(_mm512_shuffle_f64x2(a2a3b2b3, c2c3d2d3, NDA_SHUFFLE_MASK4(0, 2, 0, 2)));
110+
simd_cd4 a3b3c3d3(_mm512_shuffle_f64x2(a2a3b2b3, c2c3d2d3, NDA_SHUFFLE_MASK4(1, 3, 1, 3)));
111+
112+
return {a0b0c0d0, a1b1c1d1, a2b2c2d2, a3b3c3d3};
113+
}
114+
} // namespace nda::simd
115+
6116
#endif

test/c++/nda_simd.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -658,11 +658,11 @@ void simd_kernel_transpose() {
658658
std::array<simd_t, Width> simd_block;
659659
std::array<std::array<T, Width>, Width> array_block;
660660
for (int i = 0; i < Width; ++i) {
661-
// std::array<T, Width> tmp = generate_random_array<T, Width>();
662-
std::array<T, Width> tmp;
663-
for (int j = 0; j < Width ; ++j) {
664-
tmp[j] = Width * i + j;
665-
}
661+
std::array<T, Width> tmp = generate_random_array<T, Width>();
662+
// std::array<T, Width> tmp;
663+
// for (int j = 0; j < Width ; ++j) {
664+
// tmp[j] = Width * i + j;
665+
// }
666666
simd_block[i].load_unaligned(tmp.data());
667667
array_block[i] = tmp;
668668
}
@@ -1639,8 +1639,8 @@ TEST(NDA, SimdKernelTranspose) {
16391639
// simd_kernel_transpose<double, 8, abi_tag::AVX512>();
16401640
// simd_kernel_transpose<int32_t, 16, abi_tag::AVX512>();
16411641
// simd_kernel_transpose<int64_t, 8, abi_tag::AVX512>();
1642-
// simd_kernel_transpose<std::complex<float>, 8, abi_tag::AVX512>();
1643-
// simd_kernel_transpose<std::complex<double>, 4, abi_tag::AVX512>();
1642+
simd_kernel_transpose<std::complex<float>, 8, abi_tag::AVX512>();
1643+
simd_kernel_transpose<std::complex<double>, 4, abi_tag::AVX512>();
16441644
#endif
16451645

16461646
}
@@ -1685,11 +1685,6 @@ TEST(NDA, OurSIMD) {
16851685
};
16861686
using dcomplex = std::complex<double>;
16871687
using simd_t = native_simd<dcomplex>;
1688-
std::array<simd_t, 2> tests;
1689-
tests[0] = simd_t({0, 1, 2, 3});
1690-
tests[1] = simd_t({4, 5, 6, 7});
1691-
auto testing = simd::kernel_transpose(tests);
1692-
// auto testing = simd::transpose(tests);
16931688

16941689
const long size1 = 2;
16951690
const long size2 = 10;

0 commit comments

Comments
 (0)