Skip to content

Commit 73474f4

Browse files
committed
Remove conditional check in tsne iDFT2D1gpu kernel
1 parent a64a6b8 commit 73474f4

File tree

5 files changed

+23
-18
lines changed

5 files changed

+23
-18
lines changed

tsne/CUDA/src/kernels/nbodyfft.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,15 +351,16 @@ void iDFT2D1gpu(thrust::complex<float>* din, thrust::complex<float>* dout, int n
351351
angle = 2.0f * PI * fdividef((float)i, (float)num_cols);
352352
sum = 0.0f;
353353
#pragma unroll
354-
for (int k = 0; k < num_cols; ++k) {
354+
for (int k = 0; k < num_cols/2 + 1; ++k) {
355355
// sincosf(angle * k, &sinf, &cosf);
356356
// twiddle = thrust::complex<float>(cosf, sinf);
357357
TWIDDLE();
358-
if (k < (num_cols/2+1)) {
359-
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
360-
} else {
361-
sum = sum + thrust::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
362-
}
358+
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
359+
}
360+
#pragma unroll
361+
for (int k = num_cols/2 + 1; k < num_cols; ++k) {
362+
TWIDDLE();
363+
sum = sum + thrust::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
363364
}
364365

365366
dout[i * num_rows + j] = sum;

tsne/HIP/src/kernels/nbodyfft.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,16 @@ void iDFT2D1gpu(thrust::complex<float>* din, thrust::complex<float>* dout, int n
360360
angle = 2.0f * PI * fdividef((float)i, (float)num_cols);
361361
sum = 0.0f;
362362
#pragma unroll
363-
for (int k = 0; k < num_cols; ++k) {
363+
for (int k = 0; k < num_cols/2 + 1; ++k) {
364364
// sincosf(angle * k, &sinf, &cosf);
365365
// twiddle = thrust::complex<float>(cosf, sinf);
366366
TWIDDLE();
367-
if (k < (num_cols/2+1)) {
368-
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
369-
} else {
370-
sum = sum + thrust::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
371-
}
367+
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
368+
}
369+
#pragma unroll
370+
for (int k = num_cols/2 + 1; k < num_cols; ++k) {
371+
TWIDDLE();
372+
sum = sum + thrust::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
372373
}
373374

374375
dout[i * num_rows + j] = sum;

tsne/SYCL/src/kernels/nbodyfft.dp.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,17 +358,18 @@ void iDFT2D1gpu(std::complex<float>* din, std::complex<float>* dout, int num_row
358358
angle = 2.0f * PI * ((float)i / (float)num_cols);
359359
sum = 0.0f;
360360
#pragma unroll
361-
for (int k = 0; k < num_cols; ++k) {
361+
for (int k = 0; k < num_cols/2 + 1; ++k) {
362362
// sinf = sycl::sin(angle * k);
363363
// cosf = sycl::cos(angle * k);
364364
// sinf = sycl::sincos(angle * k, sycl::make_ptr<float, sycl::access::address_space::private_space>(&cosf));
365365
// twiddle = std::complex<float>(cosf, sinf);
366366
TWIDDLE();
367-
if (k < (num_cols/2+1)) {
368-
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
369-
} else {
370-
sum = sum + std::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
371-
}
367+
sum = sum + din[j * (num_cols/2+1) + k] * twiddle;
368+
}
369+
#pragma unroll
370+
for (int k = num_cols/2 + 1; k < num_cols; ++k) {
371+
TWIDDLE();
372+
sum = sum + std::conj(din[((num_rows-j)%num_rows) * (num_cols/2+1) + ((num_cols-k)%num_cols)]) * twiddle;
372373
}
373374

374375
dout[i * num_rows + j] = sum;

tsne/SYCL/src/utils/debug_utils.dp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
*
3838
*/
3939

40+
#include <complex>
4041
#include <sycl/sycl.hpp>
4142
#include "include/utils/debug_utils.h"
4243

tsne/SYCL/src/utils/matrix_broadcast_utils.dp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
* Copyright (c) 2018, Regents of the University of California
3838
*/
3939

40+
#include <complex>
4041
#include <sycl/sycl.hpp>
4142
#include "include/utils/matrix_broadcast_utils.h"
4243

0 commit comments

Comments
 (0)