33#include " module_base/macros.h"
44
55#include < base/macros/macros.h>
6+ #include < cuda_runtime.h>
7+ #include < thrust/complex.h>
8+ #include < thrust/execution_policy.h>
9+ #include < thrust/inner_product.h>
610
711namespace ModuleBase {
812
@@ -197,7 +201,7 @@ inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d
197201 thrust::complex <FPTYPE>* result_tmp = reinterpret_cast <thrust::complex <FPTYPE>*>(result);
198202 const thrust::complex <FPTYPE>* vector_tmp = reinterpret_cast <const thrust::complex <FPTYPE>*>(vector);
199203
200- int thread = thread_per_block ;
204+ int thread = THREADS_PER_BLOCK ;
201205 int block = (dim + thread - 1 ) / thread;
202206 vector_div_constant_kernel<thrust::complex <FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector_tmp, constant);
203207
@@ -213,7 +217,7 @@ inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
213217{
214218 thrust::complex <FPTYPE>* result_tmp = reinterpret_cast <thrust::complex <FPTYPE>*>(result);
215219 const thrust::complex <FPTYPE>* vector1_tmp = reinterpret_cast <const thrust::complex <FPTYPE>*>(vector1);
216- int thread = thread_per_block ;
220+ int thread = THREADS_PER_BLOCK ;
217221 int block = (dim + thread - 1 ) / thread;
218222 vector_mul_vector_kernel<thrust::complex <FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector1_tmp, vector2);
219223
0 commit comments