@@ -154,4 +154,69 @@ void cal_ylm_real_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
// Explicit instantiations of cal_ylm_real_op on base_device::DEVICE_GPU,
// one for float and one for double.
154154template struct cal_ylm_real_op <float , base_device::DEVICE_GPU>;
155155template struct cal_ylm_real_op <double , base_device::DEVICE_GPU>;
156156
157+
158+ // The following are kernels for the new blas_connector.
159+
160+
// Element-wise product kernel: result[k] = vector1[k] * vector2[k] for every
// k in [0, size).
//
// Indexing uses only the x components of blockIdx/blockDim/threadIdx and each
// thread handles at most one element, so the launch must supply at least
// `size` threads along x. Threads whose index is >= size do nothing.
//
// size    : number of elements to process
// result  : output array; written at indices [0, size)
// vector1 : first operand, element type T
// vector2 : second operand, element type GetTypeReal<T>::type (declared
//           elsewhere; presumably the real scalar type matching T -- confirm)
template <typename T>
__global__ void vector_mul_vector_kernel(
    const int size,
    T* result,
    const T* vector1,
    const typename GetTypeReal<T>::type* vector2)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Surplus threads in the last block have no element to work on.
    if (idx >= size)
    {
        return;
    }

    result[idx] = vector1[idx] * vector2[idx];
}
174+
// Element-wise quotient kernel: result[k] = vector1[k] / vector2[k] for every
// k in [0, size).
//
// Indexing uses only the x components of blockIdx/blockDim/threadIdx and each
// thread handles at most one element, so the launch must supply at least
// `size` threads along x. Threads whose index is >= size do nothing.
// No check is made for zero entries in vector2.
//
// size    : number of elements to process
// result  : output array; written at indices [0, size)
// vector1 : numerators, element type T
// vector2 : denominators, element type GetTypeReal<T>::type (declared
//           elsewhere; presumably the real scalar type matching T -- confirm)
template <typename T>
__global__ void vector_div_vector_kernel(
    const int size,
    T* result,
    const T* vector1,
    const typename GetTypeReal<T>::type* vector2)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Surplus threads in the last block have no element to work on.
    if (idx >= size)
    {
        return;
    }

    result[idx] = vector1[idx] / vector2[idx];
}
188+
// Host-side launcher of vector_div_constant_kernel for complex data.
//
// Reinterprets the std::complex<FPTYPE> buffers as thrust::complex<FPTYPE>
// (this relies on the two types sharing the same memory layout) and launches
// vector_div_constant_kernel<thrust::complex<FPTYPE>> with one thread per
// element: thread_per_block threads per block and ceil(dim / thread_per_block)
// blocks. No stream is passed to the launch. The kernel is defined elsewhere;
// presumably it computes result[i] = vector[i] / constant -- confirm there.
//
// d        : device tag; not used in the body
// dim      : number of elements; nothing is launched when dim <= 0
// result   : output buffer, forwarded to the kernel
// vector   : input buffer, forwarded to the kernel
// constant : real scalar forwarded to the kernel
template <typename FPTYPE>
inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d,
                                                const int dim,
                                                std::complex<FPTYPE>* result,
                                                const std::complex<FPTYPE>* vector,
                                                const FPTYPE constant)
{
    // A launch with a zero (or negative, i.e. wrapped-around) grid size is an
    // invalid configuration and would leave a CUDA error pending; an empty
    // vector simply needs no work.
    if (dim <= 0)
    {
        return;
    }

    thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
    const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);

    const int thread = thread_per_block;
    // Ceil-division written so that it cannot overflow for dim close to
    // INT_MAX; equal to (dim + thread - 1) / thread for every dim > 0.
    const int block = (dim - 1) / thread + 1;
    vector_div_constant_kernel<thrust::complex<FPTYPE>><<<block, thread>>>(dim, result_tmp, vector_tmp, constant);

    cudaCheckOnDebug();
}
205+
// Host-side launcher of vector_mul_vector_kernel for a complex vector times a
// real vector: result[i] = vector1[i] * vector2[i] for i in [0, dim).
//
// Reinterprets the std::complex<FPTYPE> buffers as thrust::complex<FPTYPE>
// (this relies on the two types sharing the same memory layout) and launches
// vector_mul_vector_kernel<thrust::complex<FPTYPE>> with one thread per
// element: thread_per_block threads per block and ceil(dim / thread_per_block)
// blocks. No stream is passed to the launch. All three buffers are
// dereferenced inside the kernel, so they must be device-accessible.
//
// d       : device tag; not used in the body
// dim     : number of elements; nothing is launched when dim <= 0
// result  : output buffer, written at indices [0, dim)
// vector1 : complex operand
// vector2 : real operand
template <typename FPTYPE>
inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
                                              const int& dim,
                                              std::complex<FPTYPE>* result,
                                              const std::complex<FPTYPE>* vector1,
                                              const FPTYPE* vector2)
{
    // A launch with a zero (or negative, i.e. wrapped-around) grid size is an
    // invalid configuration and would leave a CUDA error pending; an empty
    // vector simply needs no work.
    if (dim <= 0)
    {
        return;
    }

    thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
    const thrust::complex<FPTYPE>* vector1_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector1);

    const int thread = thread_per_block;
    // Ceil-division written so that it cannot overflow for dim close to
    // INT_MAX; equal to (dim + thread - 1) / thread for every dim > 0.
    const int block = (dim - 1) / thread + 1;
    vector_mul_vector_kernel<thrust::complex<FPTYPE>><<<block, thread>>>(dim, result_tmp, vector1_tmp, vector2);

    cudaCheckOnDebug();
}
221+
157222} // namespace ModuleBase
0 commit comments