|
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include <mkl.h>
#include <sycl/sycl.hpp>
#include <oneapi/mkl/dft.hpp>
#include <oneapi/mkl/rng.hpp>
#include <oneapi/mkl/vm.hpp>
#include <oneapi/mkl/blas.hpp>
20 | 20 |
|
21 |
| -template <typename T, typename I> |
22 |
| -using maxloc = sycl::ext::oneapi::maximum<std::pair<T, I>>; |
| 21 | +template <typename T> |
| 22 | +static bool is_device_accessible(const T* x, const sycl::queue& Q) { |
| 23 | + sycl::usm::alloc alloc_type = sycl::get_pointer_type(x, Q.get_context()); |
| 24 | + return (alloc_type == sycl::usm::alloc::shared |
| 25 | + || alloc_type == sycl::usm::alloc::device); |
| 26 | +} |
23 | 27 |
|
24 |
| -constexpr size_t L = 1; |
| 28 | +static sycl::event |
| 29 | +naive_cross_correlation(sycl::queue& Q, |
| 30 | + unsigned int N, |
| 31 | + const float* u, |
| 32 | + const float* v, |
| 33 | + float* w, |
| 34 | + const std::vector<sycl::event> deps = {}) { |
| 35 | + // u, v and w must be USM allocations of (at least) N device-accessible float |
| 36 | + // values (w must be writable) |
| 37 | + if (!is_device_accessible(u, Q) || |
| 38 | + !is_device_accessible(v, Q) || |
| 39 | + !is_device_accessible(w, Q)) { |
| 40 | + throw std::invalid_argument("Data arrays must be device-accessible"); |
| 41 | + } |
| 42 | + auto ev = Q.parallel_for(sycl::range<1>{N}, deps, [=](sycl::id<1> item) { |
| 43 | + size_t s = item.get(0); |
| 44 | + w[s] = 0.0f; |
| 45 | + for (size_t j = 0; j < N; j++) { |
| 46 | + w[s] += u[j] * v[(j - s + N) % N]; |
| 47 | + } |
| 48 | + }); |
| 49 | + return ev; |
| 50 | +} |
25 | 51 |
|
26 | 52 | int main(int argc, char** argv) {
|
27 | 53 | unsigned int N = (argc == 1) ? 32 : std::stoi(argv[1]);
|
28 |
| - if ((N % 2) != 0) N++; |
29 |
| - if (N < 32) N = 32; |
| 54 | + // N >= 8 required for the arbitrary signals as defined herein |
| 55 | + if (N < 8) |
| 56 | + throw std::invalid_argument("The input value N must be 8 or greater."); |
| 57 | + |
| 58 | + // Let s be an integer s.t. 0 <= s < N and let |
| 59 | + // corr[s] = \sum_{j = 0}^{N-1} sig1[j] sig2[(j - s + N) mod N] |
| 60 | + // be the cross-correlation between two real periodic signals sig1 and sig2 |
| 61 | + // of period N. This code shows how to calculate corr using Discrete Fourier |
| 62 | + // Transforms (DFTs). |
| 63 | + // 0 (resp. 1) is returned if naive and DFT-based calculations are (resp. |
| 64 | + // are not) within error tolerance of one another. |
| 65 | + int return_code = 0; |
30 | 66 |
|
31 | 67 | // Initialize SYCL queue
|
32 | 68 | sycl::queue Q(sycl::default_selector_v);
|
33 | 69 | std::cout << "Running on: "
|
34 |
| - << Q.get_device().get_info<sycl::info::device::name>() << "\n"; |
35 |
| - |
36 |
| - // Initialize signal and correlation arrays |
37 |
| - auto sig1 = sycl::malloc_shared<float>(N + 2, Q); |
38 |
| - auto sig2 = sycl::malloc_shared<float>(N + 2, Q); |
39 |
| - auto corr = sycl::malloc_shared<float>(N + 2, Q); |
40 |
| - |
41 |
| - // Initialize input signals with artificial data |
| 70 | + << Q.get_device().get_info<sycl::info::device::name>() |
| 71 | + << std::endl; |
| 72 | + // Initialize signal and correlation arrays. The arrays must be large enough |
| 73 | + // to store the forward and backward domains' data, consisting of N real |
| 74 | + // values and (N/2 + 1) complex values, respectively (for the DFT-based |
| 75 | + // calculations). |
| 76 | + auto sig1 = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q); |
| 77 | + auto sig2 = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q); |
| 78 | + auto corr = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q); |
| 79 | + // Array used for calculating corr without Discrete Fourier Transforms |
| 80 | + // (for comparison purposes): |
| 81 | + auto naive_corr = sycl::malloc_shared<float>(N, Q); |
| 82 | + |
| 83 | + // Initialize input signals with artificial "noise" data (random values of |
| 84 | + // magnitude much smaller than relevant signal data points) |
42 | 85 | std::uint32_t seed = (unsigned)time(NULL); // Get RNG seed value
|
43 | 86 | oneapi::mkl::rng::mcg31m1 engine(Q, seed); // Initialize RNG engine
|
44 | 87 | // Set RNG distribution
|
45 | 88 | oneapi::mkl::rng::uniform<float, oneapi::mkl::rng::uniform_method::standard>
|
46 |
| - rng_distribution(-0.00005, 0.00005); |
| 89 | + rng_distribution(-0.00005f, 0.00005f); |
47 | 90 |
|
48 |
| - auto evt1 = |
49 |
| - oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1); // Noise |
| 91 | + auto evt1 = oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1); |
50 | 92 | auto evt2 = oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2);
|
51 |
| - evt1.wait(); |
52 |
| - evt2.wait(); |
53 |
| - |
54 |
| - Q.single_task<>([=]() { |
55 |
| - sig1[N - N / 4 - 1] = 1.0; |
56 |
| - sig1[N - N / 4] = 1.0; |
57 |
| - sig1[N - N / 4 + 1] = 1.0; // Signal |
58 |
| - sig2[N / 4 - 1] = 1.0; |
59 |
| - sig2[N / 4] = 1.0; |
60 |
| - sig2[N / 4 + 1] = 1.0; |
61 |
| - }) |
62 |
| - .wait(); |
63 | 93 |
|
64 |
| - // Initialize FFT descriptor |
| 94 | + // Set the (relevant) signal data as shifted versions of one another |
| 95 | + auto evt = Q.single_task<>({evt1, evt2}, [=]() { |
| 96 | + sig1[N - N / 4 - 1] = 1.0f; |
| 97 | + sig1[N - N / 4] = 1.0f; |
| 98 | + sig1[N - N / 4 + 1] = 1.0f; |
| 99 | + sig2[N / 4 - 1] = 1.0f; |
| 100 | + sig2[N / 4] = 1.0f; |
| 101 | + sig2[N / 4 + 1] = 1.0f; |
| 102 | + }); |
| 103 | + // Calculate L2 norms of both input signals before proceeding (for |
| 104 | + // normalization purposes and for the definition of error tolerance) |
| 105 | + float *norm_sig1 = sycl::malloc_shared<float>(1, Q); |
| 106 | + float *norm_sig2 = sycl::malloc_shared<float>(1, Q); |
| 107 | + evt1 = oneapi::mkl::blas::nrm2(Q, N, sig1, 1, norm_sig1, {evt}); |
| 108 | + evt2 = oneapi::mkl::blas::nrm2(Q, N, sig2, 1, norm_sig2, {evt}); |
| 109 | + // 1) Calculate the cross-correlation naively (for verification purposes); |
| 110 | + naive_cross_correlation(Q, N, sig1, sig2, naive_corr, {evt}).wait(); |
| 111 | + // 2) Calculate the cross-correlation via Discrete Fourier Transforms (DFTs): |
| 112 | + // corr = (1/N) * iDFT(DFT(sig1) * CONJ(DFT(sig2))) |
| 113 | + // Initialize DFT descriptor |
65 | 114 | oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
|
66 |
| - oneapi::mkl::dft::domain::REAL> |
67 |
| - transform_plan(N); |
68 |
| - transform_plan.commit(Q); |
69 |
| - |
70 |
| - // Perform forward transforms on real arrays |
71 |
| - evt1 = oneapi::mkl::dft::compute_forward(transform_plan, sig1); |
72 |
| - evt2 = oneapi::mkl::dft::compute_forward(transform_plan, sig2); |
73 |
| - |
74 |
| - // Compute: DFT(sig1) * CONJG(DFT(sig2)) |
75 |
| - oneapi::mkl::vm::mulbyconj( |
76 |
| - Q, N / 2, reinterpret_cast<std::complex<float>*>(sig1), |
77 |
| - reinterpret_cast<std::complex<float>*>(sig2), |
78 |
| - reinterpret_cast<std::complex<float>*>(corr), {evt1, evt2}) |
79 |
| - .wait(); |
80 |
| - |
81 |
| - // Perform backward transform on complex correlation array |
82 |
| - oneapi::mkl::dft::compute_backward(transform_plan, corr).wait(); |
83 |
| - |
84 |
| - // Find the shift that gives maximum correlation value using parallel maxloc |
85 |
| - // reduction |
86 |
| - std::pair<float, int>* max_res = |
87 |
| - sycl::malloc_shared<std::pair<float, int>>(1, Q); |
88 |
| - std::pair<float, int> max_identity = {std::numeric_limits<float>::min(), |
89 |
| - std::numeric_limits<int>::min()}; |
90 |
| - *max_res = max_identity; |
91 |
| - auto red_max = reduction(max_res, max_identity, maxloc<float, int>()); |
92 |
| - |
93 |
| - Q.parallel_for(sycl::nd_range<1>{N, L}, red_max, |
94 |
| - [=](sycl::nd_item<1> item, auto& max_res) { |
95 |
| - int i = item.get_global_id(0); |
96 |
| - std::pair<float, int> partial = {corr[i], i}; |
97 |
| - max_res.combine(partial); |
98 |
| - }) |
99 |
| - .wait(); |
100 |
| - float max_corr = max_res->first; |
101 |
| - int shift = max_res->second; |
102 |
| - shift = |
103 |
| - (shift > N / 2) ? shift - N : shift; // Treat the signals as circularly |
104 |
| - // shifted versions of each other. |
105 |
| - std::cout << "Shift the second signal " << shift |
106 |
| - << " elements relative to the first signal to get a maximum, " |
107 |
| - "normalized correlation score of " |
108 |
| - << max_corr / N << ".\n"; |
| 115 | + oneapi::mkl::dft::domain::REAL> desc(N); |
| 116 | + desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, 1.0f / N); |
| 117 | + desc.commit(Q); |
| 118 | + // Compute in-place forward transforms of both signals: |
| 119 | + // sig1 <- DFT(sig1) |
| 120 | + evt1 = oneapi::mkl::dft::compute_forward(desc, sig1, {evt1}); |
| 121 | + // sig2 <- DFT(sig2) |
| 122 | + evt2 = oneapi::mkl::dft::compute_forward(desc, sig2, {evt2}); |
| 123 | + // Compute the element-wise multipication of (complex) coefficients in |
| 124 | + // backward domain: |
| 125 | + // corr <- sig1 * CONJ(sig2) [component-wise] |
| 126 | + evt = oneapi::mkl::vm::mulbyconj( |
| 127 | + Q, N / 2 + 1, |
| 128 | + reinterpret_cast<std::complex<float>*>(sig1), |
| 129 | + reinterpret_cast<std::complex<float>*>(sig2), |
| 130 | + reinterpret_cast<std::complex<float>*>(corr), {evt1, evt2}); |
| 131 | + // Compute in-place (scaled) backward transform: |
| 132 | + // corr <- (1/N) * iDFT(corr) |
| 133 | + oneapi::mkl::dft::compute_backward(desc, corr, {evt}).wait(); |
| 134 | + |
| 135 | + // Error bound for naive calculations: |
| 136 | + float max_err_threshold = |
| 137 | + 2.0f * std::numeric_limits<float>::epsilon() * norm_sig1[0] * norm_sig2[0]; |
| 138 | + // Adding an (empirical) error bound for the DFT-based calculation defined as |
| 139 | + // epsilon * O(log(N)) * scaling_factor * nrm2(input data), |
| 140 | + // wherein (for the last DFT at play) |
| 141 | + // - scaling_factor = 1.0 / N; |
| 142 | + // - nrm2(input data) = norm_sig1[0] * norm_sig2[0] * N |
| 143 | + // - O(log(N)) ~ 2 * log(N) [arbitrary choice; implementation-dependent behavior] |
| 144 | + max_err_threshold += |
| 145 | + 2.0f * logf(N) * std::numeric_limits<float>::epsilon() |
| 146 | + * norm_sig1[0] * norm_sig2[0]; |
| 147 | + // Verify results by comparing DFT-based and naive calculations to each other, |
| 148 | + // and fetch optimal shift maximizing correlation (DFT-based calculation). |
| 149 | + float max_err = 0.0f; |
| 150 | + float max_corr = corr[0]; |
| 151 | + int optimal_shift = 0; |
| 152 | + for (size_t s = 0; s < N; s++) { |
| 153 | + const float local_err = fabs(naive_corr[s] - corr[s]); |
| 154 | + if (local_err > max_err) |
| 155 | + max_err = local_err; |
| 156 | + if (max_err > max_err_threshold) { |
| 157 | + std::cerr << "An error was found when verifying the results." << std::endl; |
| 158 | + std::cerr << "For shift value s = " << s << ":" << std::endl; |
| 159 | + std::cerr << "\tNaive calculation results in " << naive_corr[s] << std::endl; |
| 160 | + std::cerr << "\tFourier-based calculation results in " << corr[s] << std::endl; |
| 161 | + std::cerr << "The error (" << max_err |
| 162 | + << ") exceeds the threshold value of " |
| 163 | + << max_err_threshold << std::endl; |
| 164 | + return_code = 1; |
| 165 | + break; |
| 166 | + } |
| 167 | + if (corr[s] > max_corr) { |
| 168 | + max_corr = corr[s]; |
| 169 | + optimal_shift = s; |
| 170 | + } |
| 171 | + } |
| 172 | + // Conclude: |
| 173 | + if (return_code == 0) { |
| 174 | + // Get average and standard deviation of either signal for normalizing the |
| 175 | + // correlation "score" |
| 176 | + const float avg_sig1 = sig1[0] / N; |
| 177 | + const float avg_sig2 = sig2[0] / N; |
| 178 | + const float std_dev_sig1 = |
| 179 | + sqrt((norm_sig1[0] * norm_sig1[0] - N * avg_sig1 * avg_sig1) / N); |
| 180 | + const float std_dev_sig2 = |
| 181 | + sqrt((norm_sig2[0] * norm_sig2[0] - N * avg_sig2 * avg_sig2) / N); |
| 182 | + const float normalized_corr = |
| 183 | + (max_corr / N - avg_sig1 * avg_sig2) / (std_dev_sig1 * std_dev_sig2); |
| 184 | + std::cout << "Right-shift the second signal " << optimal_shift |
| 185 | + << " elements to get a maximum, normalized correlation score of " |
| 186 | + << normalized_corr |
| 187 | + << " (treating the signals as periodic)." << std::endl; |
| 188 | + std::cout << "Max difference between naive and Fourier-based calculations : " |
| 189 | + << max_err << " (verification threshold: " << max_err_threshold |
| 190 | + << ")." << std::endl; |
| 191 | + } |
109 | 192 |
|
110 | 193 | // Cleanup
|
111 | 194 | sycl::free(sig1, Q);
|
112 | 195 | sycl::free(sig2, Q);
|
113 | 196 | sycl::free(corr, Q);
|
114 |
| - sycl::free(max_res, Q); |
| 197 | + sycl::free(naive_corr, Q); |
| 198 | + sycl::free(norm_sig1, Q); |
| 199 | + sycl::free(norm_sig2, Q); |
| 200 | + return return_code; |
115 | 201 | }
|
0 commit comments