|
6 | 6 | //
|
7 | 7 | // Content:
|
8 | 8 | // This code implements the 1D Fourier correlation algorithm
|
9 |
| -// using SYCL, oneMKL, oneDPL, and explicit buffering. |
| 9 | +// using SYCL, oneMKL, and explicit buffering. |
10 | 10 | //
|
11 | 11 | // =============================================================
|
12 | 12 |
|
13 |
| -#include <oneapi/dpl/algorithm> |
14 |
| -#include <oneapi/dpl/execution> |
15 |
| -#include <oneapi/dpl/iterator> |
16 |
| - |
| 13 | +#include <mkl.h> |
17 | 14 | #include <sycl/sycl.hpp>
|
18 |
| -#include <oneapi/mkl/dfti.hpp> |
| 15 | +#include <iostream> |
| 16 | +#include <oneapi/mkl/dft.hpp> |
19 | 17 | #include <oneapi/mkl/rng.hpp>
|
20 | 18 | #include <oneapi/mkl/vm.hpp>
|
21 |
| -#include <mkl.h> |
| 19 | +#include <oneapi/mkl/blas.hpp> |
22 | 20 |
|
23 |
| -#include <iostream> |
24 |
| -#include <string> |
| 21 | +static void |
| 22 | +naive_cross_correlation(sycl::queue& Q, |
| 23 | + unsigned int N, |
| 24 | + sycl::buffer<float>& u, |
| 25 | + sycl::buffer<float>& v, |
| 26 | + sycl::buffer<float>& w) { |
| 27 | + const size_t min_byte_size = N * sizeof(float); |
| 28 | + if (u.byte_size() < min_byte_size || |
| 29 | + v.byte_size() < min_byte_size || |
| 30 | + w.byte_size() < min_byte_size) { |
| 31 | + throw std::invalid_argument("All buffers must contain at least N float values"); |
| 32 | + } |
| 33 | + Q.submit([&](sycl::handler &cgh) { |
| 34 | + auto u_acc = u.get_access<sycl::access::mode::read>(cgh); |
| 35 | + auto v_acc = v.get_access<sycl::access::mode::read>(cgh); |
| 36 | + auto w_acc = w.get_access<sycl::access::mode::write>(cgh); |
| 37 | + cgh.parallel_for(sycl::range<1>{N}, [=](sycl::id<1> id) { |
| 38 | + const size_t s = id.get(0); |
| 39 | + w_acc[s] = 0.0f; |
| 40 | + for (size_t j = 0; j < N; j++) { |
| 41 | + w_acc[s] += u_acc[j] * v_acc[(j - s + N) % N]; |
| 42 | + } |
| 43 | + }); |
| 44 | + }); |
| 45 | +} |
25 | 46 |
|
26 | 47 | int main(int argc, char **argv) {
|
27 | 48 | unsigned int N = (argc == 1) ? 32 : std::stoi(argv[1]);
|
28 |
| - if ((N % 2) != 0) N++; |
29 |
| - if (N < 32) N = 32; |
| 49 | + // N >= 8 required for the arbitrary signals as defined herein |
| 50 | + if (N < 8) |
| 51 | + throw std::invalid_argument("The input value N must be 8 or greater."); |
| 52 | + |
| 53 | + // Let s be an integer s.t. 0 <= s < N and let |
| 54 | + // corr[s] = \sum_{j = 0}^{N-1} sig1[j] sig2[(j - s + N) mod N] |
| 55 | + // be the cross-correlation between two real periodic signals sig1 and sig2 |
| 56 | + // of period N. This code shows how to calculate corr using Discrete Fourier |
| 57 | + // Transforms (DFTs). |
| 58 | + // 0 (resp. 1) is returned if naive and DFT-based calculations are (resp. |
| 59 | + // are not) within error tolerance of one another. |
| 60 | + int return_code = 0; |
30 | 61 |
|
31 | 62 | // Initialize SYCL queue
|
32 | 63 | sycl::queue Q(sycl::default_selector_v);
|
33 | 64 | std::cout << "Running on: "
|
34 |
| - << Q.get_device().get_info<sycl::info::device::name>() << "\n"; |
35 |
| - |
36 |
| - // Create buffers for signal data. This will only be used on the device. |
37 |
| - sycl::buffer<float> sig1_buf{N + 2}; |
38 |
| - sycl::buffer<float> sig2_buf{N + 2}; |
39 |
| - sycl::buffer<float> corr_buf{N + 2}; |
| 65 | + << Q.get_device().get_info<sycl::info::device::name>() |
| 66 | + << std::endl; |
| 67 | + // Initialize signal and correlation buffers. The buffers must be large enough |
| 68 | + // to store the forward and backward domains' data, consisting of N real |
| 69 | + // values and (N/2 + 1) complex values, respectively (for the DFT-based |
| 70 | + // calculations). |
| 71 | + sycl::buffer<float> sig1{2 * (N / 2 + 1)}; |
| 72 | + sycl::buffer<float> sig2{2 * (N / 2 + 1)}; |
| 73 | + sycl::buffer<float> corr{2 * (N / 2 + 1)}; |
| 74 | + // Buffer used for calculating corr without Discrete Fourier Transforms |
| 75 | + // (for comparison purposes): |
| 76 | + sycl::buffer<float> naive_corr{N}; |
40 | 77 |
|
41 |
| - // Initialize the input signals with artificial data |
| 78 | + // Initialize input signals with artificial "noise" data (random values of |
| 79 | + // magnitude much smaller than relevant signal data points) |
42 | 80 | std::uint32_t seed = (unsigned)time(NULL); // Get RNG seed value
|
43 | 81 | oneapi::mkl::rng::mcg31m1 engine(Q, seed); // Initialize RNG engine
|
44 | 82 | // Set RNG distribution
|
45 | 83 | oneapi::mkl::rng::uniform<float, oneapi::mkl::rng::uniform_method::standard>
|
46 |
| - rng_distribution(-0.00005, 0.00005); |
| 84 | + rng_distribution(-0.00005f, 0.00005f); |
47 | 85 |
|
48 |
| - oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1_buf); // Noise |
49 |
| - oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2_buf); |
| 86 | + oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1); |
| 87 | + oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2); |
50 | 88 |
|
51 |
| - Q.submit([&](sycl::handler &h) { |
52 |
| - sycl::accessor sig1_acc{sig1_buf, h, sycl::write_only}; |
53 |
| - sycl::accessor sig2_acc{sig2_buf, h, sycl::write_only}; |
54 |
| - h.single_task<>([=]() { |
55 |
| - sig1_acc[N - N / 4 - 1] = 1.0; |
56 |
| - sig1_acc[N - N / 4] = 1.0; |
57 |
| - sig1_acc[N - N / 4 + 1] = 1.0; // Signal |
58 |
| - sig2_acc[N / 4 - 1] = 1.0; |
59 |
| - sig2_acc[N / 4] = 1.0; |
60 |
| - sig2_acc[N / 4 + 1] = 1.0; |
| 89 | + // Set the (relevant) signal data as shifted versions of one another |
| 90 | + Q.submit([&](sycl::handler &cgh) { |
| 91 | + sycl::accessor sig1_acc{sig1, cgh, sycl::write_only}; |
| 92 | + sycl::accessor sig2_acc{sig2, cgh, sycl::write_only}; |
| 93 | + cgh.single_task<>([=]() { |
| 94 | + sig1_acc[N - N / 4 - 1] = 1.0f; |
| 95 | + sig1_acc[N - N / 4] = 1.0f; |
| 96 | + sig1_acc[N - N / 4 + 1] = 1.0f; |
| 97 | + sig2_acc[N / 4 - 1] = 1.0f; |
| 98 | + sig2_acc[N / 4] = 1.0f; |
| 99 | + sig2_acc[N / 4 + 1] = 1.0f; |
61 | 100 | });
|
62 |
| - }); // End signal initialization |
63 |
| - |
64 |
| - // Initialize FFT descriptor |
| 101 | + }); |
| 102 | + // Calculate L2 norms of both input signals before proceeding (for |
| 103 | + // normalization purposes and for the definition of error tolerance) |
| 104 | + float norm_sig1, norm_sig2; |
| 105 | + { |
| 106 | + sycl::buffer<float> temp{1}; |
| 107 | + oneapi::mkl::blas::nrm2(Q, N, sig1, 1, temp); |
| 108 | + norm_sig1 = temp.get_host_access(sycl::read_only)[0]; |
| 109 | + oneapi::mkl::blas::nrm2(Q, N, sig2, 1, temp); |
| 110 | + norm_sig2 = temp.get_host_access(sycl::read_only)[0]; |
| 111 | + } |
| 112 | + // 1) Calculate the cross-correlation naively (for verification purposes); |
| 113 | + naive_cross_correlation(Q, N, sig1, sig2, naive_corr); |
| 114 | + // 2) Calculate the cross-correlation via Discrete Fourier Transforms (DFTs): |
| 115 | + // corr = (1/N) * iDFT(DFT(sig1) * CONJ(DFT(sig2))) |
| 116 | + // Initialize DFT descriptor |
65 | 117 | oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
|
66 |
| - oneapi::mkl::dft::domain::REAL> |
67 |
| - transform_plan(N); |
68 |
| - transform_plan.commit(Q); |
69 |
| - |
70 |
| - // Perform forward transforms on real arrays |
71 |
| - oneapi::mkl::dft::compute_forward(transform_plan, sig1_buf); |
72 |
| - oneapi::mkl::dft::compute_forward(transform_plan, sig2_buf); |
73 |
| - |
74 |
| - // Compute: DFT(sig1) * CONJG(DFT(sig2)) |
75 |
| - auto sig1_buf_cplx = |
76 |
| - sig1_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2); |
77 |
| - auto sig2_buf_cplx = |
78 |
| - sig2_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2); |
79 |
| - auto corr_buf_cplx = |
80 |
| - corr_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2); |
81 |
| - oneapi::mkl::vm::mulbyconj(Q, N / 2, sig1_buf_cplx, sig2_buf_cplx, |
82 |
| - corr_buf_cplx); |
83 |
| - |
84 |
| - // Perform backward transform on complex correlation array |
85 |
| - oneapi::mkl::dft::compute_backward(transform_plan, corr_buf); |
86 |
| - |
87 |
| - // Find the shift that gives maximum correlation value |
88 |
| - auto policy = oneapi::dpl::execution::make_device_policy(Q); |
89 |
| - auto maxloc = oneapi::dpl::max_element(policy, |
90 |
| - oneapi::dpl::begin(corr_buf), |
91 |
| - oneapi::dpl::end(corr_buf)); |
92 |
| - int shift = oneapi::dpl::distance(oneapi::dpl::begin(corr_buf), maxloc); |
93 |
| - float max_corr = corr_buf.get_host_access()[shift]; |
| 118 | + oneapi::mkl::dft::domain::REAL> desc(N); |
| 119 | + desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, 1.0f / N); |
| 120 | + desc.commit(Q); |
| 121 | + // Compute in-place forward transforms of both signals: |
| 122 | + // sig1 <- DFT(sig1) |
| 123 | + oneapi::mkl::dft::compute_forward(desc, sig1); |
| 124 | + // sig2 <- DFT(sig2) |
| 125 | + oneapi::mkl::dft::compute_forward(desc, sig2); |
| 126 | + // Compute the element-wise multiplication of (complex) coefficients in |
| 127 | + // backward domain: |
| 128 | + // corr <- sig1 * CONJ(sig2) [component-wise] |
| 129 | + auto sig1_cplx = |
| 130 | + sig1.template reinterpret<std::complex<float>, 1>(N / 2 + 1); |
| 131 | + auto sig2_cplx = |
| 132 | + sig2.template reinterpret<std::complex<float>, 1>(N / 2 + 1); |
| 133 | + auto corr_cplx = |
| 134 | + corr.template reinterpret<std::complex<float>, 1>(N / 2 + 1); |
| 135 | + oneapi::mkl::vm::mulbyconj(Q, N / 2 + 1, |
| 136 | + sig1_cplx, sig2_cplx, corr_cplx); |
| 137 | + // Compute in-place (scaled) backward transform: |
| 138 | + // corr <- (1/N) * iDFT(corr) |
| 139 | + oneapi::mkl::dft::compute_backward(desc, corr); |
94 | 140 |
|
95 |
| - shift = |
96 |
| - (shift > N / 2) ? shift - N : shift; // Treat the signals as circularly |
97 |
| - // shifted versions of each other. |
98 |
| - std::cout << "Shift the second signal " << shift |
99 |
| - << " elements relative to the first signal to get a maximum, " |
100 |
| - "normalized correlation score of " |
101 |
| - << max_corr / N << ".\n"; |
| 141 | + // Error bound for naive calculations: |
| 142 | + float max_err_threshold = |
| 143 | + 2.0f * std::numeric_limits<float>::epsilon() * norm_sig1 * norm_sig2; |
| 144 | + // Adding an (empirical) error bound for the DFT-based calculation defined as |
| 145 | + // epsilon * O(log(N)) * scaling_factor * nrm2(input data), |
| 146 | + // wherein (for the last DFT at play) |
| 147 | + // - scaling_factor = 1.0 / N; |
| 148 | + // - nrm2(input data) = norm_sig1 * norm_sig2 * N |
| 149 | + // - O(log(N)) ~ 2 * log(N) [arbitrary choice; implementation-dependent behavior] |
| 150 | + max_err_threshold += |
| 151 | + 2.0f * logf(N) * std::numeric_limits<float>::epsilon() |
| 152 | + * norm_sig1 * norm_sig2; |
| 153 | + // Verify results by comparing DFT-based and naive calculations to each other, |
| 154 | + // and fetch optimal shift maximizing correlation (DFT-based calculation). |
| 155 | + auto naive_corr_acc = naive_corr.get_host_access(sycl::read_only); |
| 156 | + auto corr_acc = corr.get_host_access(sycl::read_only); |
| 157 | + float max_err = 0.0f; |
| 158 | + float max_corr = corr_acc[0]; |
| 159 | + int optimal_shift = 0; |
| 160 | + for (size_t s = 0; s < N; s++) { |
| 161 | + const float local_err = fabs(naive_corr_acc[s] - corr_acc[s]); |
| 162 | + if (local_err > max_err) |
| 163 | + max_err = local_err; |
| 164 | + if (max_err > max_err_threshold) { |
| 165 | + std::cerr << "An error was found when verifying the results." << std::endl; |
| 166 | + std::cerr << "For shift value s = " << s << ":" << std::endl; |
| 167 | + std::cerr << "\tNaive calculation results in " << naive_corr_acc[s] << std::endl; |
| 168 | + std::cerr << "\tFourier-based calculation results in " << corr_acc[s] << std::endl; |
| 169 | + std::cerr << "The error (" << max_err |
| 170 | + << ") exceeds the threshold value of " |
| 171 | + << max_err_threshold << std::endl; |
| 172 | + return_code = 1; |
| 173 | + break; |
| 174 | + } |
| 175 | + if (corr_acc[s] > max_corr) { |
| 176 | + max_corr = corr_acc[s]; |
| 177 | + optimal_shift = s; |
| 178 | + } |
| 179 | + } |
| 180 | + // Conclude: |
| 181 | + if (return_code == 0) { |
| 182 | + // Get average and standard deviation of each signal for normalizing the |
| 183 | + // correlation "score" |
| 184 | + const float avg_sig1 = sig1.get_host_access(sycl::read_only)[0] / N; |
| 185 | + const float avg_sig2 = sig2.get_host_access(sycl::read_only)[0] / N; |
| 186 | + const float std_dev_sig1 = |
| 187 | + sqrt((norm_sig1 * norm_sig1 - N * avg_sig1 * avg_sig1) / N); |
| 188 | + const float std_dev_sig2 = |
| 189 | + sqrt((norm_sig2 * norm_sig2 - N * avg_sig2 * avg_sig2) / N); |
| 190 | + const float normalized_corr = |
| 191 | + (max_corr / N - avg_sig1 * avg_sig2) / (std_dev_sig1 * std_dev_sig2); |
| 192 | + std::cout << "Right-shift the second signal " << optimal_shift |
| 193 | + << " elements to get a maximum, normalized correlation score of " |
| 194 | + << normalized_corr |
| 195 | + << " (treating the signals as periodic)." << std::endl; |
| 196 | + std::cout << "Max difference between naive and Fourier-based calculations : " |
| 197 | + << max_err << " (verification threshold: " << max_err_threshold |
| 198 | + << ")." << std::endl; |
| 199 | + } |
| 200 | + return return_code; |
102 | 201 | }
|
0 commit comments