Skip to content

Commit 31d4dff

Browse files
committed
add vector_add_vector kernel
1 parent 91e8dc2 commit 31d4dff

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

source/module_base/blas_connector.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,4 +691,76 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
691691
hsolver::vector_div_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
692692
#endif
693693
}
694+
}
695+
696+
void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type)
697+
{
698+
if (device_type == base_device::CpuDevice){
699+
#ifdef _OPENMP
700+
#pragma omp parallel for schedule(static, 8192 / sizeof(float))
701+
#endif
702+
for (int i = 0; i < dim; i++)
703+
{
704+
result[i] = vector1[i] * constant1 + vector2[i] * constant2;
705+
}
706+
}
707+
else if (device_type == base_device::GpuDevice){
708+
#ifdef __CUDA
709+
hsolver::constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
710+
#endif
711+
}
712+
}
713+
714+
void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type)
715+
{
716+
if (device_type == base_device::CpuDevice){
717+
#ifdef _OPENMP
718+
#pragma omp parallel for schedule(static, 8192 / sizeof(double))
719+
#endif
720+
for (int i = 0; i < dim; i++)
721+
{
722+
result[i] = vector1[i] * constant1 + vector2[i] * constant2;
723+
}
724+
}
725+
else if (device_type == base_device::GpuDevice){
726+
#ifdef __CUDA
727+
hsolver::constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
728+
#endif
729+
}
730+
}
731+
732+
void vector_add_vector(const int& dim, std::complex<float> *result, const std::complex<float> *vector1, const float constant1, const std::complex<float> *vector2, const float constant2, base_device::AbacusDevice_t device_type)
733+
{
734+
if (device_type == base_device::CpuDevice){
735+
#ifdef _OPENMP
736+
#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex<float>))
737+
#endif
738+
for (int i = 0; i < dim; i++)
739+
{
740+
result[i] = vector1[i] * constant1 + vector2[i] * constant2;
741+
}
742+
}
743+
else if (device_type == base_device::GpuDevice){
744+
#ifdef __CUDA
745+
hsolver::constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
746+
#endif
747+
}
748+
}
749+
750+
void vector_add_vector(const int& dim, std::complex<double> *result, const std::complex<double> *vector1, const double constant1, const std::complex<double> *vector2, const double constant2, base_device::AbacusDevice_t device_type)
751+
{
752+
if (device_type == base_device::CpuDevice){
753+
#ifdef _OPENMP
754+
#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex<double>))
755+
#endif
756+
for (int i = 0; i < dim; i++)
757+
{
758+
result[i] = vector1[i] * constant1 + vector2[i] * constant2;
759+
}
760+
}
761+
else if (device_type == base_device::GpuDevice){
762+
#ifdef __CUDA
763+
hsolver::constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
764+
#endif
765+
}
694766
}

source/module_base/blas_connector.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,19 @@ class BlasConnector
312312
template <typename T>
313313
static
314314
void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
315+
316+
// y = alpha * x + beta * y
317+
static
318+
void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
319+
320+
static
321+
void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
322+
323+
static
324+
void vector_add_vector(const int& dim, std::complex<float> *result, const std::complex<float> *vector1, const float constant1, const std::complex<float> *vector2, const float constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
325+
326+
static
327+
void vector_add_vector(const int& dim, std::complex<double> *result, const std::complex<double> *vector1, const double constant1, const std::complex<double> *vector2, const double constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
315328
};
316329

317330
// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,

source/module_hsolver/kernels/cuda/math_kernel_op.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,7 @@ template struct line_minimize_with_block_op<std::complex<float>, base_device::DE
10431043
template struct vector_div_constant_op<std::complex<float>, base_device::DEVICE_GPU>;
10441044
template struct vector_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
10451045
template struct vector_div_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
1046+
template struct constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>;
10461047
template struct constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>;
10471048
template struct matrixSetToAnother<std::complex<float>, base_device::DEVICE_GPU>;
10481049

0 commit comments

Comments
 (0)