Skip to content

Commit 73d83b7

Browse files
committed
[fcorr_1d_buffer] updating and revising fcorr_1d_buffer
1 parent ea77a9f commit 73d83b7

File tree

1 file changed

+168
-69
lines changed

1 file changed

+168
-69
lines changed

Libraries/oneMKL/fourier_correlation/fcorr_1d_buffers.cpp

Lines changed: 168 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,97 +6,196 @@
66
//
77
// Content:
88
// This code implements the 1D Fourier correlation algorithm
9-
// using SYCL, oneMKL, oneDPL, and explicit buffering.
9+
// using SYCL, oneMKL, and explicit buffering.
1010
//
1111
// =============================================================
1212

13-
#include <oneapi/dpl/algorithm>
14-
#include <oneapi/dpl/execution>
15-
#include <oneapi/dpl/iterator>
16-
13+
#include <mkl.h>
1714
#include <sycl/sycl.hpp>
18-
#include <oneapi/mkl/dfti.hpp>
15+
#include <iostream>
16+
#include <oneapi/mkl/dft.hpp>
1917
#include <oneapi/mkl/rng.hpp>
2018
#include <oneapi/mkl/vm.hpp>
21-
#include <mkl.h>
19+
#include <oneapi/mkl/blas.hpp>
2220

23-
#include <iostream>
24-
#include <string>
21+
static void
22+
naive_cross_correlation(sycl::queue& Q,
23+
unsigned int N,
24+
sycl::buffer<float>& u,
25+
sycl::buffer<float>& v,
26+
sycl::buffer<float>& w) {
27+
const size_t min_byte_size = N * sizeof(float);
28+
if (u.byte_size() < min_byte_size ||
29+
v.byte_size() < min_byte_size ||
30+
w.byte_size() < min_byte_size) {
31+
throw std::invalid_argument("All buffers must contain at least N float values");
32+
}
33+
Q.submit([&](sycl::handler &cgh) {
34+
auto u_acc = u.get_access<sycl::access::mode::read>(cgh);
35+
auto v_acc = v.get_access<sycl::access::mode::read>(cgh);
36+
auto w_acc = w.get_access<sycl::access::mode::write>(cgh);
37+
cgh.parallel_for(sycl::range<1>{N}, [=](sycl::id<1> id) {
38+
const size_t s = id.get(0);
39+
w_acc[s] = 0.0f;
40+
for (size_t j = 0; j < N; j++) {
41+
w_acc[s] += u_acc[j] * v_acc[(j - s + N) % N];
42+
}
43+
});
44+
});
45+
}
2546

2647
int main(int argc, char **argv) {
2748
unsigned int N = (argc == 1) ? 32 : std::stoi(argv[1]);
28-
if ((N % 2) != 0) N++;
29-
if (N < 32) N = 32;
49+
// N >= 8 required for the arbitrary signals as defined herein
50+
if (N < 8)
51+
throw std::invalid_argument("The input value N must be 8 or greater.");
52+
53+
// Let s be an integer s.t. 0 <= s < N and let
54+
// corr[s] = \sum_{j = 0}^{N-1} sig1[j] sig2[(j - s + N) mod N]
55+
// be the cross-correlation between two real periodic signals sig1 and sig2
56+
// of period N. This code shows how to calculate corr using Discrete Fourier
57+
// Transforms (DFTs).
58+
// 0 (resp. 1) is returned if naive and DFT-based calculations are (resp.
59+
// are not) within error tolerance of one another.
60+
int return_code = 0;
3061

3162
// Initialize SYCL queue
3263
sycl::queue Q(sycl::default_selector_v);
3364
std::cout << "Running on: "
34-
<< Q.get_device().get_info<sycl::info::device::name>() << "\n";
35-
36-
// Create buffers for signal data. This will only be used on the device.
37-
sycl::buffer<float> sig1_buf{N + 2};
38-
sycl::buffer<float> sig2_buf{N + 2};
39-
sycl::buffer<float> corr_buf{N + 2};
65+
<< Q.get_device().get_info<sycl::info::device::name>()
66+
<< std::endl;
67+
// Initialize signal and correlation buffers. The buffers must be large enough
68+
// to store the forward and backward domains' data, consisting of N real
69+
// values and (N/2 + 1) complex values, respectively (for the DFT-based
70+
// calculations).
71+
sycl::buffer<float> sig1{2 * (N / 2 + 1)};
72+
sycl::buffer<float> sig2{2 * (N / 2 + 1)};
73+
sycl::buffer<float> corr{2 * (N / 2 + 1)};
74+
// Buffer used for calculating corr without Discrete Fourier Transforms
75+
// (for comparison purposes):
76+
sycl::buffer<float> naive_corr{N};
4077

41-
// Initialize the input signals with artificial data
78+
// Initialize input signals with artificial "noise" data (random values of
79+
// magnitude much smaller than relevant signal data points)
4280
std::uint32_t seed = (unsigned)time(NULL); // Get RNG seed value
4381
oneapi::mkl::rng::mcg31m1 engine(Q, seed); // Initialize RNG engine
4482
// Set RNG distribution
4583
oneapi::mkl::rng::uniform<float, oneapi::mkl::rng::uniform_method::standard>
46-
rng_distribution(-0.00005, 0.00005);
84+
rng_distribution(-0.00005f, 0.00005f);
4785

48-
oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1_buf); // Noise
49-
oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2_buf);
86+
oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1);
87+
oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2);
5088

51-
Q.submit([&](sycl::handler &h) {
52-
sycl::accessor sig1_acc{sig1_buf, h, sycl::write_only};
53-
sycl::accessor sig2_acc{sig2_buf, h, sycl::write_only};
54-
h.single_task<>([=]() {
55-
sig1_acc[N - N / 4 - 1] = 1.0;
56-
sig1_acc[N - N / 4] = 1.0;
57-
sig1_acc[N - N / 4 + 1] = 1.0; // Signal
58-
sig2_acc[N / 4 - 1] = 1.0;
59-
sig2_acc[N / 4] = 1.0;
60-
sig2_acc[N / 4 + 1] = 1.0;
89+
// Set the (relevant) signal data as shifted versions of one another
90+
Q.submit([&](sycl::handler &cgh) {
91+
sycl::accessor sig1_acc{sig1, cgh, sycl::write_only};
92+
sycl::accessor sig2_acc{sig2, cgh, sycl::write_only};
93+
cgh.single_task<>([=]() {
94+
sig1_acc[N - N / 4 - 1] = 1.0f;
95+
sig1_acc[N - N / 4] = 1.0f;
96+
sig1_acc[N - N / 4 + 1] = 1.0f;
97+
sig2_acc[N / 4 - 1] = 1.0f;
98+
sig2_acc[N / 4] = 1.0f;
99+
sig2_acc[N / 4 + 1] = 1.0f;
61100
});
62-
}); // End signal initialization
63-
64-
// Initialize FFT descriptor
101+
});
102+
// Calculate L2 norms of both input signals before proceeding (for
103+
// normalization purposes and for the definition of error tolerance)
104+
float norm_sig1, norm_sig2;
105+
{
106+
sycl::buffer<float> temp{1};
107+
oneapi::mkl::blas::nrm2(Q, N, sig1, 1, temp);
108+
norm_sig1 = temp.get_host_access(sycl::read_only)[0];
109+
oneapi::mkl::blas::nrm2(Q, N, sig2, 1, temp);
110+
norm_sig2 = temp.get_host_access(sycl::read_only)[0];
111+
}
112+
// 1) Calculate the cross-correlation naively (for verification purposes);
113+
naive_cross_correlation(Q, N, sig1, sig2, naive_corr);
114+
// 2) Calculate the cross-correlation via Discrete Fourier Transforms (DFTs):
115+
// corr = (1/N) * iDFT(DFT(sig1) * CONJ(DFT(sig2)))
116+
// Initialize DFT descriptor
65117
oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
66-
oneapi::mkl::dft::domain::REAL>
67-
transform_plan(N);
68-
transform_plan.commit(Q);
69-
70-
// Perform forward transforms on real arrays
71-
oneapi::mkl::dft::compute_forward(transform_plan, sig1_buf);
72-
oneapi::mkl::dft::compute_forward(transform_plan, sig2_buf);
73-
74-
// Compute: DFT(sig1) * CONJG(DFT(sig2))
75-
auto sig1_buf_cplx =
76-
sig1_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2);
77-
auto sig2_buf_cplx =
78-
sig2_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2);
79-
auto corr_buf_cplx =
80-
corr_buf.template reinterpret<std::complex<float>, 1>((N + 2) / 2);
81-
oneapi::mkl::vm::mulbyconj(Q, N / 2, sig1_buf_cplx, sig2_buf_cplx,
82-
corr_buf_cplx);
83-
84-
// Perform backward transform on complex correlation array
85-
oneapi::mkl::dft::compute_backward(transform_plan, corr_buf);
86-
87-
// Find the shift that gives maximum correlation value
88-
auto policy = oneapi::dpl::execution::make_device_policy(Q);
89-
auto maxloc = oneapi::dpl::max_element(policy,
90-
oneapi::dpl::begin(corr_buf),
91-
oneapi::dpl::end(corr_buf));
92-
int shift = oneapi::dpl::distance(oneapi::dpl::begin(corr_buf), maxloc);
93-
float max_corr = corr_buf.get_host_access()[shift];
118+
oneapi::mkl::dft::domain::REAL> desc(N);
119+
desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, 1.0f / N);
120+
desc.commit(Q);
121+
// Compute in-place forward transforms of both signals:
122+
// sig1 <- DFT(sig1)
123+
oneapi::mkl::dft::compute_forward(desc, sig1);
124+
// sig2 <- DFT(sig2)
125+
oneapi::mkl::dft::compute_forward(desc, sig2);
126+
// Compute the element-wise multipication of (complex) coefficients in
127+
// backward domain:
128+
// corr <- sig1 * CONJ(sig2) [component-wise]
129+
auto sig1_cplx =
130+
sig1.template reinterpret<std::complex<float>, 1>(N / 2 + 1);
131+
auto sig2_cplx =
132+
sig2.template reinterpret<std::complex<float>, 1>(N / 2 + 1);
133+
auto corr_cplx =
134+
corr.template reinterpret<std::complex<float>, 1>(N / 2 + 1);
135+
oneapi::mkl::vm::mulbyconj(Q, N / 2 + 1,
136+
sig1_cplx, sig2_cplx, corr_cplx);
137+
// Compute in-place (scaled) backward transform:
138+
// corr <- (1/N) * iDFT(corr)
139+
oneapi::mkl::dft::compute_backward(desc, corr);
94140

95-
shift =
96-
(shift > N / 2) ? shift - N : shift; // Treat the signals as circularly
97-
// shifted versions of each other.
98-
std::cout << "Shift the second signal " << shift
99-
<< " elements relative to the first signal to get a maximum, "
100-
"normalized correlation score of "
101-
<< max_corr / N << ".\n";
141+
// Error bound for naive calculations:
142+
float max_err_threshold =
143+
2.0f * std::numeric_limits<float>::epsilon() * norm_sig1 * norm_sig2;
144+
// Adding an (empirical) error bound for the DFT-based calculation defined as
145+
// epsilon * O(log(N)) * scaling_factor * nrm2(input data),
146+
// wherein (for the last DFT at play)
147+
// - scaling_factor = 1.0 / N;
148+
// - nrm2(input data) = norm_sig1 * norm_sig2 * N
149+
// - O(log(N)) ~ 2 * log(N) [arbitrary choice; implementation-dependent behavior]
150+
max_err_threshold +=
151+
2.0f * logf(N) * std::numeric_limits<float>::epsilon()
152+
* norm_sig1 * norm_sig2;
153+
// Verify results by comparing DFT-based and naive calculations to each other,
154+
// and fetch optimal shift maximizing correlation (DFT-based calculation).
155+
auto naive_corr_acc = naive_corr.get_host_access(sycl::read_only);
156+
auto corr_acc = corr.get_host_access(sycl::read_only);
157+
float max_err = 0.0f;
158+
float max_corr = corr_acc[0];
159+
int optimal_shift = 0;
160+
for (size_t s = 0; s < N; s++) {
161+
const float local_err = fabs(naive_corr_acc[s] - corr_acc[s]);
162+
if (local_err > max_err)
163+
max_err = local_err;
164+
if (max_err > max_err_threshold) {
165+
std::cerr << "An error was found when verifying the results." << std::endl;
166+
std::cerr << "For shift value s = " << s << ":" << std::endl;
167+
std::cerr << "\tNaive calculation results in " << naive_corr_acc[s] << std::endl;
168+
std::cerr << "\tFourier-based calculation results in " << corr_acc[s] << std::endl;
169+
std::cerr << "The error (" << max_err
170+
<< ") exceeds the threshold value of "
171+
<< max_err_threshold << std::endl;
172+
return_code = 1;
173+
break;
174+
}
175+
if (corr_acc[s] > max_corr) {
176+
max_corr = corr_acc[s];
177+
optimal_shift = s;
178+
}
179+
}
180+
// Conclude:
181+
if (return_code == 0) {
182+
// Get average and standard deviation of either signal for normalizing the
183+
// correlation "score"
184+
const float avg_sig1 = sig1.get_host_access(sycl::read_only)[0] / N;
185+
const float avg_sig2 = sig2.get_host_access(sycl::read_only)[0] / N;
186+
const float std_dev_sig1 =
187+
sqrt((norm_sig1 * norm_sig1 - N * avg_sig1 * avg_sig1) / N);
188+
const float std_dev_sig2 =
189+
sqrt((norm_sig2 * norm_sig2 - N * avg_sig2 * avg_sig2) / N);
190+
const float normalized_corr =
191+
(max_corr / N - avg_sig1 * avg_sig2) / (std_dev_sig1 * std_dev_sig2);
192+
std::cout << "Right-shift the second signal " << optimal_shift
193+
<< " elements to get a maximum, normalized correlation score of "
194+
<< normalized_corr
195+
<< " (treating the signals as periodic)." << std::endl;
196+
std::cout << "Max difference between naive and Fourier-based calculations : "
197+
<< max_err << " (verification threshold: " << max_err_threshold
198+
<< ")." << std::endl;
199+
}
200+
return return_code;
102201
}

0 commit comments

Comments
 (0)