Skip to content

Commit a0e7a02

Browse files
committed
prepare for PR
2 parents 909f291 + 025b199 commit a0e7a02

File tree

11 files changed

+41
-40
lines changed

11 files changed

+41
-40
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ ${OBJS_DELTASPIN}\
114114
${OBJS_TENSOR}\
115115
${OBJS_HSOLVER_PEXSI}\
116116
${OBJS_LR}\
117-
${OBJS_RDMFT}\
117+
${OBJS_RDMFT}
118118

119119
OBJS_MAIN=main.o\
120120
driver.o\

source/module_base/blas_connector.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#ifdef __DSP
44
#include "module_base/kernels/dsp/dsp_connector.h"
5+
#include "module_base/global_variable.h"
56
#endif
67

78
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
@@ -94,7 +95,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
9495
else if (device_type == base_device::AbacusDevice_t::DspDevice){
9596
sgemm_mt_(&transb, &transa, &n, &m, &k,
9697
&alpha, b, &ldb, a, &lda,
97-
&beta, c, &ldc);
98+
&beta, c, &ldc, GlobalV::MY_RANK);
9899
}
99100
#endif
100101
}
@@ -112,7 +113,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
112113
else if (device_type == base_device::AbacusDevice_t::DspDevice){
113114
dgemm_mt_(&transb, &transa, &n, &m, &k,
114115
&alpha, b, &ldb, a, &lda,
115-
&beta, c, &ldc);
116+
&beta, c, &ldc, GlobalV::MY_RANK);
116117
}
117118
#endif
118119
}
@@ -130,7 +131,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
130131
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
131132
cgemm_mt_(&transb, &transa, &n, &m, &k,
132133
&alpha, b, &ldb, a, &lda,
133-
&beta, c, &ldc);
134+
&beta, c, &ldc, GlobalV::MY_RANK);
134135
}
135136
#endif
136137
}
@@ -148,7 +149,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
148149
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
149150
zgemm_mt_(&transb, &transa, &n, &m, &k,
150151
&alpha, b, &ldb, a, &lda,
151-
&beta, c, &ldc);
152+
&beta, c, &ldc, GlobalV::MY_RANK);
152153
}
153154
#endif
154155
}

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
// Base dsp functions
66
void dspInitHandle(int id);
7-
void dspDestoryHandle();
8-
void *malloc_ht(size_t bytes);
7+
void dspDestoryHandle(int id);
8+
void *malloc_ht(size_t bytes, int cluster_id);
99
void free_ht(void* ptr);
1010

1111

@@ -15,50 +15,50 @@ void sgemm_mt_(const char *transa, const char *transb,
1515
const int *m, const int *n, const int *k,
1616
const float *alpha, const float *a, const int *lda,
1717
const float *b, const int *ldb, const float *beta,
18-
float *c, const int *ldc);
18+
float *c, const int *ldc, int cluster_id);
1919

2020
void dgemm_mt_(const char *transa, const char *transb,
2121
const int *m, const int *n, const int *k,
2222
const double *alpha,const double *a, const int *lda,
2323
const double *b, const int *ldb, const double *beta,
24-
double *c, const int *ldc);
24+
double *c, const int *ldc, int cluster_id);
2525

2626
void zgemm_mt_(const char *transa, const char *transb,
2727
const int *m, const int *n, const int *k,
2828
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
2929
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
30-
std::complex<double> *c, const int *ldc);
30+
std::complex<double> *c, const int *ldc, int cluster_id);
3131

3232
void cgemm_mt_(const char *transa, const char *transb,
3333
const int *m, const int *n, const int *k,
3434
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
3535
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
36-
std::complex<float> *c, const int *ldc);
36+
std::complex<float> *c, const int *ldc, int cluster_id);
3737

3838

3939
void sgemm_mth_(const char *transa, const char *transb,
4040
const int *m, const int *n, const int *k,
4141
const float *alpha, const float *a, const int *lda,
4242
const float *b, const int *ldb, const float *beta,
43-
float *c, const int *ldc);
43+
float *c, const int *ldc, int cluster_id);
4444

4545
void dgemm_mth_(const char *transa, const char *transb,
4646
const int *m, const int *n, const int *k,
4747
const double *alpha,const double *a, const int *lda,
4848
const double *b, const int *ldb, const double *beta,
49-
double *c, const int *ldc);
49+
double *c, const int *ldc, int cluster_id);
5050

5151
void zgemm_mth_(const char *transa, const char *transb,
5252
const int *m, const int *n, const int *k,
5353
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
5454
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
55-
std::complex<double> *c, const int *ldc);
55+
std::complex<double> *c, const int *ldc, int cluster_id);
5656

5757
void cgemm_mth_(const char *transa, const char *transb,
5858
const int *m, const int *n, const int *k,
5959
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
6060
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
61-
std::complex<float> *c, const int *ldc);
61+
std::complex<float> *c, const int *ldc, int cluster_id);
6262

6363
//#define zgemm_ zgemm_mt
6464

source/module_base/module_device/memory_op.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_base/tool_threading.h"
55
#ifdef __DSP
66
#include "module_base/kernels/dsp/dsp_connector.h"
7+
#include "module_base/global_variable.h"
78
#endif
89

910
#include <complex>
@@ -21,17 +22,9 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
2122
{
2223
if (arr != nullptr)
2324
{
24-
#ifdef __DSP
25-
free_ht(arr);
26-
#else
2725
free(arr);
28-
#endif
2926
}
30-
#ifdef __DSP
31-
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
32-
#else
3327
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
34-
#endif
3528
std::string record_string;
3629
if (record_in != nullptr)
3730
{
@@ -103,11 +96,7 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
10396
{
10497
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
10598
{
106-
#ifdef __DSP
107-
free_ht(arr);
108-
#else
10999
free(arr);
110-
#endif
111100
}
112101
};
113102

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
7373
#endif
7474
#ifdef __DSP
7575
std::cout << " ** Initializing DSP Hardware..." << std::endl;
76-
dspInitHandle(GlobalV::MY_RANK % 4);
76+
dspInitHandle(GlobalV::MY_RANK);
7777
#endif
7878
}
7979

@@ -102,7 +102,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
102102
}
103103
#ifdef __DSP
104104
std::cout << " ** Closing DSP Hardware..." << std::endl;
105-
dspDestoryHandle();
105+
dspDestoryHandle(GlobalV::MY_RANK);
106106
#endif
107107
if (PARAM.inp.precision == "single")
108108
{

source/module_io/read_input_item_system.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ void ReadInput::item_system()
156156
{
157157
para.input.symmetry = "-1"; // disable kpoint reduce
158158
}
159+
if (para.input.berry_phase)
160+
{
161+
para.input.symmetry = "-1"; // disable kpoint reduce
162+
}
159163
};
160164
this->add_item(item);
161165
}

source/module_io/test_serial/read_input_item_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ TEST_F(InputTest, Item_test)
151151
param.input.qo_switch = true;
152152
it->second.reset_value(it->second, param);
153153
EXPECT_EQ(param.input.symmetry, "-1");
154+
155+
param.input.symmetry = "default";
156+
param.input.berry_phase = true;
157+
it->second.reset_value(it->second, param);
158+
EXPECT_EQ(param.input.symmetry, "-1");
154159
}
155160
{ // nelec
156161
auto it = find_label("nelec", readinput.input_lists);

source/module_rdmft/CMakeLists.txt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
add_library(
2-
rdmft
3-
OBJECT
4-
rdmft.cpp
5-
rdmft_tools.cpp
6-
rdmft_test.cpp
7-
)
1+
if(ENABLE_LCAO)
2+
add_library(
3+
rdmft
4+
OBJECT
5+
rdmft.cpp
6+
rdmft_tools.cpp
7+
rdmft_test.cpp
8+
)
9+
endif()
810

911
# if(ENABLE_COVERAGE)
1012
# add_coverage(psi)

toolchain/build_abacus_gnu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,6 @@ cat << EOF
8282
========================== usage =========================
8383
Done!
8484
To use the installed ABACUS version
85-
You need to source $(pwd)/abacus_env.sh first !
85+
You need to source ${TOOL}/abacus_env.sh first !
8686
"""
8787
EOF

toolchain/build_abacus_intel-mpich.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,6 @@ cat << EOF
7373
========================== usage =========================
7474
Done!
7575
To use the installed ABACUS version
76-
You need to source $(pwd)/abacus_env.sh first !
76+
You need to source ${TOOL}/abacus_env.sh first !
7777
"""
7878
EOF

0 commit comments

Comments
 (0)