Skip to content

Commit 86bd28b

Browse files
dyzhengdyzheng
andauthored
Refactor: new memory record interface (#1794)
* Refactor: new memory record interface * Fix: CUDA and ROCM compiler * Fix: UT related to memory.cpp Co-authored-by: dyzheng <[email protected]>
1 parent 6bf1ab7 commit 86bd28b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+369
-216
lines changed

source/module_base/math_ylmreal.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <cassert>
77
#include "ylm.h"
88
#include "module_base/kernels/math_op.h"
9+
#include "module_psi/kernels/memory_op.h"
910

1011
namespace ModuleBase
1112
{
@@ -323,7 +324,7 @@ void YlmReal::Ylm_Real(Device * ctx, const int lmax2, const int ng, const FPTYPE
323324
ModuleBase::WARNING_QUIT("YLM_REAL","l>30 or l<0");
324325
}
325326
FPTYPE * p = nullptr, * phi = nullptr, * cost = nullptr;
326-
resmem_var_op()(ctx, p, (lmax + 1) * (lmax + 1) * ng);
327+
resmem_var_op()(ctx, p, (lmax + 1) * (lmax + 1) * ng, "YlmReal::Ylm_Real");
327328

328329
cal_ylm_real_op()(
329330
ctx,

source/module_base/math_ylmreal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include "vector3.h"
55
#include "matrix.h"
6-
#include "module_psi/psi.h"
76

87
namespace ModuleBase
98
{

source/module_base/memory.cpp

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// DATE : 2008-11-18
44
//==========================================================
55
#include "memory.h"
6+
#include "global_variable.h"
7+
#include "src_parallel/parallel_reduce.h"
68

79
namespace ModuleBase
810
{
@@ -18,7 +20,7 @@ int Memory::bool_memory = sizeof(bool); // 1.0 Byte
1820
int Memory::float_memory = sizeof(float); // 4.0 Byte
1921
int Memory::short_memory = sizeof(short); // 2.0 Byte
2022

21-
int Memory::n_memory = 500;
23+
int Memory::n_memory = 1000;
2224
int Memory::n_now = 0;
2325
bool Memory::init_flag = false;
2426

@@ -140,10 +142,75 @@ double Memory::record
140142
return consume[find];
141143
}
142144

145+
void Memory::record
146+
(
147+
const std::string &name_in,
148+
const size_t &n_in,
149+
const bool accumulate
150+
)
151+
{
152+
if(!Memory::init_flag)
153+
{
154+
name = new std::string[n_memory];
155+
class_name = new std::string[n_memory];
156+
consume = new double[n_memory];
157+
for(int i=0;i<n_memory;i++)
158+
{
159+
consume[i] = 0.0;
160+
}
161+
Memory::init_flag = true;
162+
}
163+
164+
int find = 0;
165+
for(find = 0; find < n_now; find++)
166+
{
167+
if( name_in == name[find] )
168+
{
169+
break;
170+
}
171+
}
172+
173+
// find == n_now : found a new record.
174+
if(find == n_now)
175+
{
176+
n_now++;
177+
name[find] = name_in;
178+
class_name[find] = "";
179+
}
180+
if(n_now >= n_memory)
181+
{
182+
std::cout<<" Error! Too many memories has been recorded.";
183+
return;
184+
}
185+
186+
const double factor = 1.0/1024.0/1024.0;
187+
double size_mb = n_in * factor;
188+
189+
if(accumulate)
190+
{
191+
consume[find] += size_mb;
192+
Memory::total += size_mb;
193+
}
194+
else
195+
{
196+
if(consume[find] < size_mb)
197+
{
198+
Memory::total += size_mb - consume[find];
199+
consume[find] = size_mb;
200+
if(consume[find] > 5)
201+
{
202+
print(find);
203+
}
204+
}
205+
}
206+
207+
return;
208+
}
209+
143210
void Memory::print(const int find)
144211
{
145-
// std::cout <<"\n Warning_Memory_Consuming : "
146-
// <<class_name[find]<<" "<<name[find]<<" "<<consume[find]<<" MB" << std::endl;
212+
GlobalV::ofs_running <<"\n Warning_Memory_Consuming allocated: "
213+
<<" "<<name[find]<<" "<<consume[find]<<" MB" << std::endl;
147214
return;
148215
}
149216

@@ -167,10 +234,12 @@ void Memory::print_all(std::ofstream &ofs)
167234
if(!init_flag) return;
168235

169236
const double small = 1.0;
170-
// std::cout<<"\n CLASS_NAME---------|NAME---------------|MEMORY(MB)--------";
171-
ofs <<"\n CLASS_NAME---------|NAME---------------|MEMORY(MB)--------" << std::endl;
237+
#ifdef __MPI
238+
Parallel_Reduce::reduce_double_all(Memory::total);
239+
#endif
240+
ofs <<"\n NAME---------------|MEMORY(MB)--------" << std::endl;
172241
// std::cout<<"\n"<<std::setw(41)<< " " <<std::setprecision(4)<<total;
173-
ofs <<std::setw(41)<< " " <<std::setprecision(4)<<total << std::endl;
242+
ofs <<std::setw(20)<< "total" << std::setw(15) <<std::setprecision(4)<< Memory::total << std::endl;
174243

175244
bool *print_flag = new bool[n_memory];
176245
for(int i=0; i<n_memory; i++) print_flag[i] = false;
@@ -192,16 +261,16 @@ void Memory::print_all(std::ofstream &ofs)
192261
}
193262
}
194263
print_flag[k] = true;
195-
264+
#ifdef __MPI
265+
Parallel_Reduce::reduce_double_all(consume[k]);
266+
#endif
196267
if ( consume[k] < small )
197268
{
198269
continue;
199270
}
200271
else
201272
{
202-
ofs << " "
203-
<< std::setw(20) << class_name[k]
204-
<< std::setw(20) << name[k]
273+
ofs << std::setw(20) << name[k]
205274
<< std::setw(15) << consume[k] << std::endl;
206275

207276
// std::cout << "\n "
@@ -211,6 +280,7 @@ void Memory::print_all(std::ofstream &ofs)
211280
}
212281
}
213282
// std::cout<<"\n ----------------------------------------------------------"<<std::endl;
283+
ofs<<" ------------- < 1.0 MB has been ignored ----------------"<<std::endl;
214284
ofs<<" ----------------------------------------------------------"<<std::endl;
215285
delete[] print_flag; //mohan fix by valgrind at 2012-04-02
216286
return;

source/module_base/memory.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ class Memory
3737
const std::string &type,
3838
const bool accumulate = false);
3939

40+
/**
41+
* @brief Record memory consumed during computation
42+
*
43+
* @param name The name of a quantity
44+
* @param n The number of the quantity
45+
* @param accumulate Useless, always set false
46+
*/
47+
static void record(
48+
const std::string &name_in,
49+
const size_t &n_in,
50+
const bool accumulate = false
51+
);
52+
4053
static double &get_total(void)
4154
{
4255
return total;

source/module_base/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ AddTest(
6868
LIBS ${math_libs} device
6969
SOURCES math_ylmreal_test.cpp ../math_ylmreal.cpp ../complexmatrix.cpp ../global_variable.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h
7070
../../src_parallel/parallel_reduce.cpp ../../src_parallel/parallel_kpoints.cpp ../../src_parallel/parallel_global.cpp ../../src_parallel/parallel_common.cpp
71+
../memory.cpp
7172
)
7273
AddTest(
7374
TARGET base_math_sphbes

source/module_base/test/memory_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
#include <fstream>
55
#include <cstdio>
66

7+
namespace GlobalV
8+
{
9+
std::ofstream ofs_running;
10+
}
11+
712
/************************************************
813
* unit test of class Memory
914
***********************************************/

source/module_deepks/test/klist_1.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ namespace Test_Deepks
105105
wk = new double[kpoint_number];
106106
isk = new int[kpoint_number];
107107

108-
ModuleBase::Memory::record("K_Vectors","kvec_c",kpoint_number*3,"double");
109-
ModuleBase::Memory::record("K_Vectors","kvec_d",kpoint_number*3,"double");
110-
ModuleBase::Memory::record("K_Vectors","wk",kpoint_number*3,"double");
111-
ModuleBase::Memory::record("K_Vectors","isk",kpoint_number*3,"int");
108+
ModuleBase::Memory::record("KV::kvec_c",sizeof(double) * kpoint_number*3);
109+
ModuleBase::Memory::record("KV::kvec_d",sizeof(double) * kpoint_number*3);
110+
ModuleBase::Memory::record("KV::wk",sizeof(double) * kpoint_number*3);
111+
ModuleBase::Memory::record("KV::isk",sizeof(int) * kpoint_number*3);
112112

113113
return;
114114
}

source/module_deepks/test/parallel_orbitals.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ namespace Test_Deepks
3535
trace_loc_col[i] = -1;
3636
}
3737

38-
ModuleBase::Memory::record("Parallel_Orbitals","trace_loc_row",GlobalV::NLOCAL,"int");
39-
ModuleBase::Memory::record("Parallel_Orbitals","trace_loc_col",GlobalV::NLOCAL,"int");
38+
ModuleBase::Memory::record("PO::trace_loc_row",sizeof(int) * GlobalV::NLOCAL);
39+
ModuleBase::Memory::record("PO::trace_loc_col",sizeof(int) * GlobalV::NLOCAL);
4040

4141
for (int i=0; i<GlobalV::NLOCAL; i++)
4242
{

source/module_dftu/dftu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ void DFTU::init(UnitCell& cell, // unitcell class
211211
}
212212
}
213213

214-
ModuleBase::Memory::record("DFTU", "locale", num_locale, "double");
214+
ModuleBase::Memory::record("DFTU::locale", sizeof(double) * num_locale);
215215
return;
216216
}
217217

source/module_elecstate/elecstate_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ void ElecStatePW<FPTYPE, Device>::init_rho_data()
5151
this->kin_r = reinterpret_cast<FPTYPE **>(this->charge->kin_r);
5252
}
5353
}
54-
resmem_complex_op()(this->ctx, this->wfcr, this->basis->nmaxgr);
55-
resmem_complex_op()(this->ctx, this->wfcr_another_spin, this->charge->nrxx);
54+
resmem_complex_op()(this->ctx, this->wfcr, this->basis->nmaxgr, "ElecSPW::wfcr");
55+
resmem_complex_op()(this->ctx, this->wfcr_another_spin, this->charge->nrxx, "ElecSPW::wfcr_a");
5656
this->init_rho = true;
5757
}
5858

0 commit comments

Comments
 (0)