From d9ac32e2e47628f40634a74901a17ef83a090ad9 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 21 Oct 2025 13:16:25 +0800 Subject: [PATCH] Add float copy function in blas connector --- .../module_external/blas_connector.h | 8 ++++++++ .../module_external/blas_connector_vector.cpp | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/source/source_base/module_external/blas_connector.h b/source/source_base/module_external/blas_connector.h index 921f94ddb9..ea6834ce2d 100644 --- a/source/source_base/module_external/blas_connector.h +++ b/source/source_base/module_external/blas_connector.h @@ -24,7 +24,9 @@ extern "C" void caxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); void zaxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); + void scopy_(long const *n, const float *a, int const *incx, float *b, int const *incy); void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); + void ccopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); //reason for passing results as argument instead of returning it: @@ -340,6 +342,12 @@ class BlasConnector static void copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static + void copy(const long n, const float *a, const int incx, float *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + + static + void copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static void copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); diff --git a/source/source_base/module_external/blas_connector_vector.cpp b/source/source_base/module_external/blas_connector_vector.cpp index b5e0972946..8f8091ba7d 100644 --- a/source/source_base/module_external/blas_connector_vector.cpp +++ b/source/source_base/module_external/blas_connector_vector.cpp @@ -322,6 +322,16 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in } // copies a into b +void BlasConnector::copy(const long n, const float *a, const int incx, float *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + scopy_(&n, a, &incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { @@ -332,6 +342,16 @@ void BlasConnector::copy(const long n, const double *a, const int incx, double * } } +void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + ccopy_(&n, a, &incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) {