Skip to content

Commit 7646974

Browse files
committed
GPU implementation
1 parent 911a80d commit 7646974

File tree

2 files changed

+74
-9
lines changed

2 files changed

+74
-9
lines changed

source/module_base/blas_connector.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec
671671
}
672672
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
673673
#ifdef __CUDA
674-
vector_mul_vector_complex_wrapper(d, dim, result, vector1, vector2);
674+
vector_mul_vector_gpu(dim, result, vector1, vector2);
675675
#endif
676676
}
677677
}
@@ -691,7 +691,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
691691
}
692692
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
693693
#ifdef __CUDA
694-
vector_div_vector_complex_wrapper(d, dim, result, vector1, vector2);
694+
vector_div_vector_gpu(dim, result, vector1, vector2);
695695
#endif
696696
}
697697
}

source/module_base/kernels/cuda/math_op.cu

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,25 +192,22 @@ __global__ void vector_div_vector_kernel(
192192
}
193193

194194
template <typename FPTYPE>
195-
inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d,
196-
const int dim,
195+
inline void vector_div_vector_complex_wrapper(const int dim,
197196
std::complex<FPTYPE>* result,
198197
const std::complex<FPTYPE>* vector1,
199198
const FPTYPE* vector2)
200199
{
201200
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
202-
const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);
203-
201+
const thrust::complex<FPTYPE>* vector1_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector1);
204202
int thread = THREADS_PER_BLOCK;
205203
int block = (dim + thread - 1) / thread;
206-
vector_div_constant_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector_tmp, constant);
204+
vector_div_vector_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector1_tmp, vector2);
207205

208206
cudaCheckOnDebug();
209207
}
210208

211209
template <typename FPTYPE>
212-
inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
213-
const int& dim,
210+
inline void vector_mul_vector_complex_wrapper(const int& dim,
214211
std::complex<FPTYPE>* result,
215212
const std::complex<FPTYPE>* vector1,
216213
const FPTYPE* vector2)
@@ -224,4 +221,72 @@ inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
224221
cudaCheckOnDebug();
225222
}
226223

224+
void vector_div_vector_gpu(const int& dim,
225+
double* result,
226+
const double* vector1,
227+
const double* vector2)
228+
{
229+
int thread = THREADS_PER_BLOCK;
230+
int block = (dim + thread - 1) / thread;
231+
vector_div_vector_kernel<double> <<<block, thread >>> (dim, result, vector1, vector2);
232+
233+
cudaCheckOnDebug();
234+
}
235+
236+
void vector_div_vector_gpu(const int& dim,
237+
float* result,
238+
const float* vector1,
239+
const float* vector2)
240+
{
241+
int thread = THREADS_PER_BLOCK;
242+
int block = (dim + thread - 1) / thread;
243+
vector_div_vector_kernel<float> <<<block, thread >>> (dim, result, vector1, vector2);
244+
245+
cudaCheckOnDebug();
246+
}
247+
248+
void vector_div_vector_gpu(const int& dim, std::complex<float>* result, const std::complex<float>* vector1, const float* vector2)
249+
{
250+
vector_div_vector_complex_wrapper(dim, result, vector1, vector2);
251+
}
252+
253+
void vector_div_vector_gpu(const int& dim, std::complex<double>* result, const std::complex<double>* vector1, const double* vector2)
254+
{
255+
vector_div_vector_complex_wrapper(dim, result, vector1, vector2);
256+
}
257+
258+
void vector_mul_vector_gpu(const int& dim,
259+
double* result,
260+
const double* vector1,
261+
const double* vector2)
262+
{
263+
int thread = THREADS_PER_BLOCK;
264+
int block = (dim + thread - 1) / thread;
265+
vector_mul_vector_kernel<double> <<<block, thread >>> (dim, result, vector1, vector2);
266+
267+
cudaCheckOnDebug();
268+
}
269+
270+
void vector_mul_vector_gpu(const int& dim,
271+
float* result,
272+
const float* vector1,
273+
const float* vector2)
274+
{
275+
int thread = THREADS_PER_BLOCK;
276+
int block = (dim + thread - 1) / thread;
277+
vector_mul_vector_kernel<float> <<<block, thread >>> (dim, result, vector1, vector2);
278+
279+
cudaCheckOnDebug();
280+
}
281+
282+
void vector_mul_vector_gpu(const int& dim, std::complex<float>* result, const std::complex<float>* vector1, const float* vector2)
283+
{
284+
vector_mul_vector_complex_wrapper(dim, result, vector1, vector2);
285+
}
286+
287+
void vector_mul_vector_gpu(const int& dim, std::complex<double>* result, const std::complex<double>* vector1, const double* vector2)
288+
{
289+
vector_mul_vector_complex_wrapper(dim, result, vector1, vector2);
290+
}
291+
227292
} // namespace ModuleBase

0 commit comments

Comments
 (0)