Skip to content

Commit ea77a9f

Browse files
committed
[fcorr_1d_usm] updating and revising fcorr_1d_usm
1 parent b3d9e93 commit ea77a9f

File tree

1 file changed

+161
-75
lines changed

1 file changed

+161
-75
lines changed

Libraries/oneMKL/fourier_correlation/fcorr_1d_usm.cpp

Lines changed: 161 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,103 +13,189 @@
1313
#include <mkl.h>
1414
#include <sycl/sycl.hpp>
1515
#include <iostream>
16-
#include <string>
17-
#include <oneapi/mkl/dfti.hpp>
16+
#include <oneapi/mkl/dft.hpp>
1817
#include <oneapi/mkl/rng.hpp>
1918
#include <oneapi/mkl/vm.hpp>
19+
#include <oneapi/mkl/blas.hpp>
2020

21-
template <typename T, typename I>
22-
using maxloc = sycl::ext::oneapi::maximum<std::pair<T, I>>;
21+
template <typename T>
22+
static bool is_device_accessible(const T* x, const sycl::queue& Q) {
23+
sycl::usm::alloc alloc_type = sycl::get_pointer_type(x, Q.get_context());
24+
return (alloc_type == sycl::usm::alloc::shared
25+
|| alloc_type == sycl::usm::alloc::device);
26+
}
2327

24-
constexpr size_t L = 1;
28+
static sycl::event
29+
naive_cross_correlation(sycl::queue& Q,
30+
unsigned int N,
31+
const float* u,
32+
const float* v,
33+
float* w,
34+
const std::vector<sycl::event> deps = {}) {
35+
// u, v and w must be USM allocations of (at least) N device-accessible float
36+
// values (w must be writable)
37+
if (!is_device_accessible(u, Q) ||
38+
!is_device_accessible(v, Q) ||
39+
!is_device_accessible(w, Q)) {
40+
throw std::invalid_argument("Data arrays must be device-accessible");
41+
}
42+
auto ev = Q.parallel_for(sycl::range<1>{N}, deps, [=](sycl::id<1> item) {
43+
size_t s = item.get(0);
44+
w[s] = 0.0f;
45+
for (size_t j = 0; j < N; j++) {
46+
w[s] += u[j] * v[(j - s + N) % N];
47+
}
48+
});
49+
return ev;
50+
}
2551

2652
int main(int argc, char** argv) {
2753
unsigned int N = (argc == 1) ? 32 : std::stoi(argv[1]);
28-
if ((N % 2) != 0) N++;
29-
if (N < 32) N = 32;
54+
// N >= 8 required for the arbitrary signals as defined herein
55+
if (N < 8)
56+
throw std::invalid_argument("The input value N must be 8 or greater.");
57+
58+
// Let s be an integer s.t. 0 <= s < N and let
59+
// corr[s] = \sum_{j = 0}^{N-1} sig1[j] sig2[(j - s + N) mod N]
60+
// be the cross-correlation between two real periodic signals sig1 and sig2
61+
// of period N. This code shows how to calculate corr using Discrete Fourier
62+
// Transforms (DFTs).
63+
// 0 (resp. 1) is returned if naive and DFT-based calculations are (resp.
64+
// are not) within error tolerance of one another.
65+
int return_code = 0;
3066

3167
// Initialize SYCL queue
3268
sycl::queue Q(sycl::default_selector_v);
3369
std::cout << "Running on: "
34-
<< Q.get_device().get_info<sycl::info::device::name>() << "\n";
35-
36-
// Initialize signal and correlation arrays
37-
auto sig1 = sycl::malloc_shared<float>(N + 2, Q);
38-
auto sig2 = sycl::malloc_shared<float>(N + 2, Q);
39-
auto corr = sycl::malloc_shared<float>(N + 2, Q);
40-
41-
// Initialize input signals with artificial data
70+
<< Q.get_device().get_info<sycl::info::device::name>()
71+
<< std::endl;
72+
// Initialize signal and correlation arrays. The arrays must be large enough
73+
// to store the forward and backward domains' data, consisting of N real
74+
// values and (N/2 + 1) complex values, respectively (for the DFT-based
75+
// calculations).
76+
auto sig1 = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q);
77+
auto sig2 = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q);
78+
auto corr = sycl::malloc_shared<float>(2 * (N / 2 + 1), Q);
79+
// Array used for calculating corr without Discrete Fourier Transforms
80+
// (for comparison purposes):
81+
auto naive_corr = sycl::malloc_shared<float>(N, Q);
82+
83+
// Initialize input signals with artificial "noise" data (random values of
84+
// magnitude much smaller than relevant signal data points)
4285
std::uint32_t seed = (unsigned)time(NULL); // Get RNG seed value
4386
oneapi::mkl::rng::mcg31m1 engine(Q, seed); // Initialize RNG engine
4487
// Set RNG distribution
4588
oneapi::mkl::rng::uniform<float, oneapi::mkl::rng::uniform_method::standard>
46-
rng_distribution(-0.00005, 0.00005);
89+
rng_distribution(-0.00005f, 0.00005f);
4790

48-
auto evt1 =
49-
oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1); // Noise
91+
auto evt1 = oneapi::mkl::rng::generate(rng_distribution, engine, N, sig1);
5092
auto evt2 = oneapi::mkl::rng::generate(rng_distribution, engine, N, sig2);
51-
evt1.wait();
52-
evt2.wait();
53-
54-
Q.single_task<>([=]() {
55-
sig1[N - N / 4 - 1] = 1.0;
56-
sig1[N - N / 4] = 1.0;
57-
sig1[N - N / 4 + 1] = 1.0; // Signal
58-
sig2[N / 4 - 1] = 1.0;
59-
sig2[N / 4] = 1.0;
60-
sig2[N / 4 + 1] = 1.0;
61-
})
62-
.wait();
6393

64-
// Initialize FFT descriptor
94+
// Set the (relevant) signal data as shifted versions of one another
95+
auto evt = Q.single_task<>({evt1, evt2}, [=]() {
96+
sig1[N - N / 4 - 1] = 1.0f;
97+
sig1[N - N / 4] = 1.0f;
98+
sig1[N - N / 4 + 1] = 1.0f;
99+
sig2[N / 4 - 1] = 1.0f;
100+
sig2[N / 4] = 1.0f;
101+
sig2[N / 4 + 1] = 1.0f;
102+
});
103+
// Calculate L2 norms of both input signals before proceeding (for
104+
// normalization purposes and for the definition of error tolerance)
105+
float *norm_sig1 = sycl::malloc_shared<float>(1, Q);
106+
float *norm_sig2 = sycl::malloc_shared<float>(1, Q);
107+
evt1 = oneapi::mkl::blas::nrm2(Q, N, sig1, 1, norm_sig1, {evt});
108+
evt2 = oneapi::mkl::blas::nrm2(Q, N, sig2, 1, norm_sig2, {evt});
109+
// 1) Calculate the cross-correlation naively (for verification purposes);
110+
naive_cross_correlation(Q, N, sig1, sig2, naive_corr, {evt}).wait();
111+
// 2) Calculate the cross-correlation via Discrete Fourier Transforms (DFTs):
112+
// corr = (1/N) * iDFT(DFT(sig1) * CONJ(DFT(sig2)))
113+
// Initialize DFT descriptor
65114
oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
66-
oneapi::mkl::dft::domain::REAL>
67-
transform_plan(N);
68-
transform_plan.commit(Q);
69-
70-
// Perform forward transforms on real arrays
71-
evt1 = oneapi::mkl::dft::compute_forward(transform_plan, sig1);
72-
evt2 = oneapi::mkl::dft::compute_forward(transform_plan, sig2);
73-
74-
// Compute: DFT(sig1) * CONJG(DFT(sig2))
75-
oneapi::mkl::vm::mulbyconj(
76-
Q, N / 2, reinterpret_cast<std::complex<float>*>(sig1),
77-
reinterpret_cast<std::complex<float>*>(sig2),
78-
reinterpret_cast<std::complex<float>*>(corr), {evt1, evt2})
79-
.wait();
80-
81-
// Perform backward transform on complex correlation array
82-
oneapi::mkl::dft::compute_backward(transform_plan, corr).wait();
83-
84-
// Find the shift that gives maximum correlation value using parallel maxloc
85-
// reduction
86-
std::pair<float, int>* max_res =
87-
sycl::malloc_shared<std::pair<float, int>>(1, Q);
88-
std::pair<float, int> max_identity = {std::numeric_limits<float>::min(),
89-
std::numeric_limits<int>::min()};
90-
*max_res = max_identity;
91-
auto red_max = reduction(max_res, max_identity, maxloc<float, int>());
92-
93-
Q.parallel_for(sycl::nd_range<1>{N, L}, red_max,
94-
[=](sycl::nd_item<1> item, auto& max_res) {
95-
int i = item.get_global_id(0);
96-
std::pair<float, int> partial = {corr[i], i};
97-
max_res.combine(partial);
98-
})
99-
.wait();
100-
float max_corr = max_res->first;
101-
int shift = max_res->second;
102-
shift =
103-
(shift > N / 2) ? shift - N : shift; // Treat the signals as circularly
104-
// shifted versions of each other.
105-
std::cout << "Shift the second signal " << shift
106-
<< " elements relative to the first signal to get a maximum, "
107-
"normalized correlation score of "
108-
<< max_corr / N << ".\n";
115+
oneapi::mkl::dft::domain::REAL> desc(N);
116+
desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, 1.0f / N);
117+
desc.commit(Q);
118+
// Compute in-place forward transforms of both signals:
119+
// sig1 <- DFT(sig1)
120+
evt1 = oneapi::mkl::dft::compute_forward(desc, sig1, {evt1});
121+
// sig2 <- DFT(sig2)
122+
evt2 = oneapi::mkl::dft::compute_forward(desc, sig2, {evt2});
123+
// Compute the element-wise multipication of (complex) coefficients in
124+
// backward domain:
125+
// corr <- sig1 * CONJ(sig2) [component-wise]
126+
evt = oneapi::mkl::vm::mulbyconj(
127+
Q, N / 2 + 1,
128+
reinterpret_cast<std::complex<float>*>(sig1),
129+
reinterpret_cast<std::complex<float>*>(sig2),
130+
reinterpret_cast<std::complex<float>*>(corr), {evt1, evt2});
131+
// Compute in-place (scaled) backward transform:
132+
// corr <- (1/N) * iDFT(corr)
133+
oneapi::mkl::dft::compute_backward(desc, corr, {evt}).wait();
134+
135+
// Error bound for naive calculations:
136+
float max_err_threshold =
137+
2.0f * std::numeric_limits<float>::epsilon() * norm_sig1[0] * norm_sig2[0];
138+
// Adding an (empirical) error bound for the DFT-based calculation defined as
139+
// epsilon * O(log(N)) * scaling_factor * nrm2(input data),
140+
// wherein (for the last DFT at play)
141+
// - scaling_factor = 1.0 / N;
142+
// - nrm2(input data) = norm_sig1[0] * norm_sig2[0] * N
143+
// - O(log(N)) ~ 2 * log(N) [arbitrary choice; implementation-dependent behavior]
144+
max_err_threshold +=
145+
2.0f * logf(N) * std::numeric_limits<float>::epsilon()
146+
* norm_sig1[0] * norm_sig2[0];
147+
// Verify results by comparing DFT-based and naive calculations to each other,
148+
// and fetch optimal shift maximizing correlation (DFT-based calculation).
149+
float max_err = 0.0f;
150+
float max_corr = corr[0];
151+
int optimal_shift = 0;
152+
for (size_t s = 0; s < N; s++) {
153+
const float local_err = fabs(naive_corr[s] - corr[s]);
154+
if (local_err > max_err)
155+
max_err = local_err;
156+
if (max_err > max_err_threshold) {
157+
std::cerr << "An error was found when verifying the results." << std::endl;
158+
std::cerr << "For shift value s = " << s << ":" << std::endl;
159+
std::cerr << "\tNaive calculation results in " << naive_corr[s] << std::endl;
160+
std::cerr << "\tFourier-based calculation results in " << corr[s] << std::endl;
161+
std::cerr << "The error (" << max_err
162+
<< ") exceeds the threshold value of "
163+
<< max_err_threshold << std::endl;
164+
return_code = 1;
165+
break;
166+
}
167+
if (corr[s] > max_corr) {
168+
max_corr = corr[s];
169+
optimal_shift = s;
170+
}
171+
}
172+
// Conclude:
173+
if (return_code == 0) {
174+
// Get average and standard deviation of either signal for normalizing the
175+
// correlation "score"
176+
const float avg_sig1 = sig1[0] / N;
177+
const float avg_sig2 = sig2[0] / N;
178+
const float std_dev_sig1 =
179+
sqrt((norm_sig1[0] * norm_sig1[0] - N * avg_sig1 * avg_sig1) / N);
180+
const float std_dev_sig2 =
181+
sqrt((norm_sig2[0] * norm_sig2[0] - N * avg_sig2 * avg_sig2) / N);
182+
const float normalized_corr =
183+
(max_corr / N - avg_sig1 * avg_sig2) / (std_dev_sig1 * std_dev_sig2);
184+
std::cout << "Right-shift the second signal " << optimal_shift
185+
<< " elements to get a maximum, normalized correlation score of "
186+
<< normalized_corr
187+
<< " (treating the signals as periodic)." << std::endl;
188+
std::cout << "Max difference between naive and Fourier-based calculations : "
189+
<< max_err << " (verification threshold: " << max_err_threshold
190+
<< ")." << std::endl;
191+
}
109192

110193
// Cleanup
111194
sycl::free(sig1, Q);
112195
sycl::free(sig2, Q);
113196
sycl::free(corr, Q);
114-
sycl::free(max_res, Q);
197+
sycl::free(naive_corr, Q);
198+
sycl::free(norm_sig1, Q);
199+
sycl::free(norm_sig2, Q);
200+
return return_code;
115201
}

0 commit comments

Comments
 (0)