Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand Down
65 changes: 65 additions & 0 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define DSP_CONNECTOR_H
#ifdef __DSP

#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include "module_hsolver/diag_comm_info.h"

// Base dsp functions
void dspInitHandle(int id);
void dspDestoryHandle(int id);
Expand Down Expand Up @@ -62,5 +66,66 @@ void cgemm_mth_(const char *transa, const char *transb,

//#define zgemm_ zgemm_mt

// The next is dsp utils. It may be moved to other files if this file get too huge

template <typename T>
void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm){

using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;

auto* swap = new T[notconv * nbase_x];
auto* target = new T[notconv * nbase_x];
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
if (base_device::get_current_precision(swap) == "single")
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_COMPLEX,
MPI_SUM,
0,
diag_comm);
}
else
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_DOUBLE_COMPLEX,
MPI_SUM,
0,
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);

if (base_device::get_current_precision(swap) == "single")
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_COMPLEX,
MPI_SUM,
0,
diag_comm);
}
else
{
MPI_Reduce(swap,
target,
notconv * nbase_x,
MPI_DOUBLE_COMPLEX,
MPI_SUM,
0,
diag_comm);
}

syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
delete[] swap;
delete[] target;
}


#endif
#endif
52 changes: 52 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,5 +346,57 @@ template struct delete_memory_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct delete_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
#endif

#ifdef __DSP

template <typename FPTYPE>
struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE*& arr, const size_t size, const char* record_in)
{
if (arr != nullptr)
{
free_ht(arr);
}
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK);
std::string record_string;
if (record_in != nullptr)
{
record_string = record_in;
}
else
{
record_string = "no_record";
}

if (record_string != "no_record")
{
ModuleBase::Memory::record(record_string, sizeof(FPTYPE) * size);
}
}
};

template <typename FPTYPE>
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
free_ht(arr);
}
};


template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
#endif

} // namespace memory
} // namespace base_device
31 changes: 30 additions & 1 deletion source/module_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,36 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_GPU>
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

#ifdef __DSP

template <typename FPTYPE, typename Device>
struct resize_memory_op_mt
{
/// @brief Allocate memory for a given pointer. Note this op will free the pointer first.
///
/// Input Parameters
/// \param dev : the type of computing device
/// \param size : array size
/// \param record_string : label for memory record
///
/// Output Parameters
/// \param arr : allocated array
void operator()(const Device* dev, FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
};

template <typename FPTYPE, typename Device>
struct delete_memory_op_mt
{
/// @brief free memory for multi-device
///
/// Input Parameters
/// \param dev : the type of computing device
/// \param arr : the input array
void operator()(const Device* dev, FPTYPE* arr);
};

#endif // __DSP

} // end of namespace memory
} // end of namespace base_device

Expand Down Expand Up @@ -233,5 +263,4 @@ using castmem_z2c_d2h_op = base_device::memory::

static base_device::DEVICE_CPU* cpu_ctx = {};
static base_device::DEVICE_GPU* gpu_ctx = {};

#endif // MODULE_DEVICE_MEMORY_H_
1 change: 1 addition & 0 deletions source/module_base/module_device/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace base_device

struct DEVICE_CPU;
struct DEVICE_GPU;
struct DEVICE_DSP;

enum AbacusDevice_t
{
Expand Down
9 changes: 8 additions & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "module_base/timer.h"
#include "module_hsolver/kernels/dngvd_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_base/kernels/dsp/dsp_connector.h"

#include <vector>

Expand Down Expand Up @@ -182,7 +183,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

#ifdef __DSP
gemm_op_mt<T, Device>()
gemm_op_mt<T, Device>() // In order to not coding another whole template, using this method to minimize the code change.
#else
gemm_op<T, Device>()
#endif
Expand Down Expand Up @@ -444,7 +445,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
#ifdef __MPI
if (this->diag_comm.nproc > 1)
{
#ifdef __DSP
// Only on dsp hardware need an extra space to reduce data
dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm);
#else
auto* swap = new T[notconv * this->nbase_x];

syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);

if (std::is_same<T, double>::value)
Expand Down Expand Up @@ -499,6 +505,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
}
}
delete[] swap;
#endif
}
#endif

Expand Down
10 changes: 10 additions & 0 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,22 @@ class Diago_DavSubspace

bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf);

#ifdef __DSP
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
#else
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
#endif
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;

#ifdef __DSP
using resmem_real_op = base_device::memory::resize_memory_op_mt<Real, Device>;
using delmem_real_op = base_device::memory::delete_memory_op_mt<Real, Device>;
#else
using resmem_real_op = base_device::memory::resize_memory_op<Real, Device>;
using delmem_real_op = base_device::memory::delete_memory_op<Real, Device>;
#endif
using setmem_real_op = base_device::memory::set_memory_op<Real, Device>;

using resmem_real_h_op = base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>;
Expand Down
8 changes: 7 additions & 1 deletion source/module_psi/psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ class Psi

bool allocate_inside = true; ///<whether allocate psi inside Psi class

using set_memory_op = base_device::memory::set_memory_op<T, Device>;
#ifdef __DSP
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
using resize_memory_op = base_device::memory::resize_memory_op_mt<T, Device>;
#else
using delete_memory_op = base_device::memory::delete_memory_op<T, Device>;
using resize_memory_op = base_device::memory::resize_memory_op<T, Device>;
#endif
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, Device, Device>;

};

} // end of namespace psi
Expand Down
Loading