Skip to content

Commit 0e17595

Browse files
committed
Add INPUT parameter of dsp counts
1 parent 4bda3c2 commit 0e17595

File tree

5 files changed

+17
-11
lines changed

5 files changed

+17
-11
lines changed

source/source_base/module_device/memory_op.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#ifdef __DSP
66
#include "source_base/kernels/dsp/dsp_connector.h"
77
#include "source_base/global_variable.h"
8+
#include "source_io/module_parameter/parameter.h"
89
#endif
910

1011
#include <complex>
@@ -452,7 +453,7 @@ struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
452453
{
453454
mtfunc::free_ht(arr);
454455
}
455-
arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK % 4);
456+
arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK % PARAM.inp.dsp_count);
456457
std::string record_string;
457458
if (record_in != nullptr)
458459
{

source/source_base/module_external/blas_connector_matrix.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#ifdef __DSP
55
#include "source_base/kernels/dsp/dsp_connector.h"
66
#include "source_base/global_variable.h"
7+
#include "source_io/module_parameter/parameter.h"
78
#endif
89

910
#ifdef __CUDA
@@ -30,7 +31,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
3031
else if (device_type == base_device::AbacusDevice_t::DspDevice){
3132
mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k,
3233
&alpha, b, &ldb, a, &lda,
33-
&beta, c, &ldc, GlobalV::MY_RANK % 4);
34+
&beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
3435
}
3536
#endif
3637
#ifdef __CUDA
@@ -67,7 +68,7 @@ void BlasConnector::gemm(const char transa,
6768
#ifdef __DSP
6869
else if (device_type == base_device::AbacusDevice_t::DspDevice)
6970
{
70-
mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4);
71+
mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
7172
}
7273
#endif
7374
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -106,7 +107,7 @@ void BlasConnector::gemm(const char transa,
106107
#ifdef __DSP
107108
else if (device_type == base_device::AbacusDevice_t::DspDevice)
108109
{
109-
mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4);
110+
mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
110111
}
111112
#endif
112113
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -157,7 +158,7 @@ void BlasConnector::gemm(const char transa,
157158
#ifdef __DSP
158159
else if (device_type == base_device::AbacusDevice_t::DspDevice)
159160
{
160-
mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4);
161+
mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
161162
}
162163
#endif
163164
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -200,7 +201,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
200201
else if (device_type == base_device::AbacusDevice_t::DspDevice){
201202
mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k,
202203
&alpha, a, &lda, b, &ldb,
203-
&beta, c, &ldc, GlobalV::MY_RANK % 4);
204+
&beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
204205
}
205206
#endif
206207
#ifdef __CUDA
@@ -237,7 +238,7 @@ void BlasConnector::gemm_cm(const char transa,
237238
#ifdef __DSP
238239
else if (device_type == base_device::AbacusDevice_t::DspDevice)
239240
{
240-
mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4);
241+
mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
241242
}
242243
#endif
243244
#ifdef __CUDA
@@ -276,7 +277,7 @@ void BlasConnector::gemm_cm(const char transa,
276277
#ifdef __DSP
277278
else if (device_type == base_device::AbacusDevice_t::DspDevice)
278279
{
279-
mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4);
280+
mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
280281
}
281282
#endif
282283
#ifdef __CUDA
@@ -327,7 +328,7 @@ void BlasConnector::gemm_cm(const char transa,
327328
#ifdef __DSP
328329
else if (device_type == base_device::AbacusDevice_t::DspDevice)
329330
{
330-
mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4);
331+
mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
331332
}
332333
#endif
333334
#ifdef __CUDA

source/source_base/module_fft/fft_dsp.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "source_base/global_variable.h"
44
#include "source_base/global_function.h"
5+
#include "source_io/module_parameter/parameter.h"
56

67
#include <iostream>
78
#include <string.h>
@@ -14,7 +15,7 @@ void FFT_DSP<double>::initfft(int nx_in, int ny_in, int nz_in)
1415
this->nx = nx_in;
1516
this->ny = ny_in;
1617
this->nz = nz_in;
17-
cluster_id = GlobalV::MY_RANK % 4;
18+
cluster_id = GlobalV::MY_RANK % PARAM.inp.dsp_count;
1819
nxyz = this->nx * this->ny * this->nz;
1920
}
2021
template <>

source/source_io/module_parameter/input_parameter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,5 +692,8 @@ struct Input_para
692692
bool of_cd = false; ///< add CD potential or not https://doi.org/10.1103/PhysRevB.98.144302
693693
double of_mCD_alpha = 1.0; /// parameter of modified CD Potential
694694

695+
// ============== #Parameters (25.uncommon hardware) =================
696+
int dsp_count = 4; /// the count of dsp hardwares in one node
697+
695698
};
696699
#endif

source/source_main/driver_run.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ void Driver::init_hardware()
129129

130130
#ifdef __DSP
131131
std::cout << " ** Initializing DSP Hardware..." << std::endl;
132-
mtfunc::dspInitHandle(GlobalV::MY_RANK % 4);
132+
mtfunc::dspInitHandle(GlobalV::MY_RANK % PARAM.inp.dsp_count);
133133
#endif
134134
}
135135

0 commit comments

Comments
 (0)