Skip to content

Commit 911a80d

Browse files
committed
Fix typename
1 parent 2bc7097 commit 911a80d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

source/module_base/kernels/cuda/math_op.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include "module_base/macros.h"
44

55
#include <base/macros/macros.h>
6+
#include <cuda_runtime.h>
7+
#include <thrust/complex.h>
8+
#include <thrust/execution_policy.h>
9+
#include <thrust/inner_product.h>
610

711
namespace ModuleBase {
812

@@ -197,7 +201,7 @@ inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d
197201
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
198202
const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);
199203

200-
int thread = thread_per_block;
204+
int thread = THREADS_PER_BLOCK;
201205
int block = (dim + thread - 1) / thread;
202206
vector_div_constant_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector_tmp, constant);
203207

@@ -213,7 +217,7 @@ inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
213217
{
214218
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
215219
const thrust::complex<FPTYPE>* vector1_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector1);
216-
int thread = thread_per_block;
220+
int thread = THREADS_PER_BLOCK;
217221
int block = (dim + thread - 1) / thread;
218222
vector_mul_vector_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector1_tmp, vector2);
219223

0 commit comments

Comments
 (0)