Skip to content

Commit 6060b43

Browse files
committed
Finish CUDA kernel
1 parent f8e9ae6 commit 6060b43

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

source/module_base/blas_connector.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <thrust/execution_policy.h>
1414
#include <thrust/inner_product.h>
1515
#include "module_base/tool_quit.h"
16+
#include "module_base/kernels/cuda/math_op.cu"
1617

1718
#include "cublas_v2.h"
1819

@@ -668,6 +669,11 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec
668669
result[i] = vector1[i] * vector2[i];
669670
}
670671
}
672+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
673+
#ifdef __CUDA
674+
vector_mul_vector_complex_wrapper(d, dim, result, vector1, vector2);
675+
#endif
676+
}
671677
}
672678

673679

@@ -683,4 +689,9 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
683689
result[i] = vector1[i] / vector2[i];
684690
}
685691
}
692+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
693+
#ifdef __CUDA
694+
vector_div_vector_complex_wrapper(d, dim, result, vector1, vector2);
695+
#endif
696+
}
686697
}

source/module_base/kernels/cuda/math_op.cu

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,69 @@ void cal_ylm_real_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
154154
template struct cal_ylm_real_op<float, base_device::DEVICE_GPU>;
155155
template struct cal_ylm_real_op<double, base_device::DEVICE_GPU>;
156156

157+
158+
// The next are kernels for new blas_connector
159+
160+
161+
template <typename T>
162+
__global__ void vector_mul_vector_kernel(
163+
const int size,
164+
T* result,
165+
const T* vector1,
166+
const typename GetTypeReal<T>::type* vector2)
167+
{
168+
int i = blockIdx.x * blockDim.x + threadIdx.x;
169+
if (i < size)
170+
{
171+
result[i] = vector1[i] * vector2[i];
172+
}
173+
}
174+
175+
template <typename T>
176+
__global__ void vector_div_vector_kernel(
177+
const int size,
178+
T* result,
179+
const T* vector1,
180+
const typename GetTypeReal<T>::type* vector2)
181+
{
182+
int i = blockIdx.x * blockDim.x + threadIdx.x;
183+
if (i < size)
184+
{
185+
result[i] = vector1[i] / vector2[i];
186+
}
187+
}
188+
189+
template <typename FPTYPE>
190+
inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d,
191+
const int dim,
192+
std::complex<FPTYPE>* result,
193+
const std::complex<FPTYPE>* vector,
194+
const FPTYPE constant)
195+
{
196+
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
197+
const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);
198+
199+
int thread = thread_per_block;
200+
int block = (dim + thread - 1) / thread;
201+
vector_div_constant_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector_tmp, constant);
202+
203+
cudaCheckOnDebug();
204+
}
205+
206+
template <typename FPTYPE>
207+
inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
208+
const int& dim,
209+
std::complex<FPTYPE>* result,
210+
const std::complex<FPTYPE>* vector1,
211+
const FPTYPE* vector2)
212+
{
213+
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
214+
const thrust::complex<FPTYPE>* vector1_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector1);
215+
int thread = thread_per_block;
216+
int block = (dim + thread - 1) / thread;
217+
vector_mul_vector_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector1_tmp, vector2);
218+
219+
cudaCheckOnDebug();
220+
}
221+
157222
} // namespace ModuleBase

0 commit comments

Comments
 (0)