44#include < complex>
55#include " module_psi/psi.h"
66#include " module_base/global_function.h"
7+ #include " module_base/tool_quit.h"
78
89namespace hamilt
910{
@@ -33,8 +34,12 @@ class Operator
3334 // this is the core function for Operator
3435 // do H|psi> from input |psi> ,
3536 // output of hpsi would be first member of the returned tuple
36- typedef std::tuple<const psi::Psi<T>*, const psi::Range> hpsi_info;
37- virtual hpsi_info hPsi (const hpsi_info& input)const {return hpsi_info (nullptr , 0 );}
37+ typedef std::tuple<const psi::Psi<T>*, const psi::Range, T*> hpsi_info;
38+ virtual hpsi_info hPsi (hpsi_info& input)const
39+ {
40+ ModuleBase::WARNING_QUIT (" Operator::hPsi" , " hPsi error!" );
41+ return hpsi_info (nullptr , 0 , nullptr );
42+ }
3843
3944 virtual void init (const int ik_in)
4045 {
@@ -65,7 +70,7 @@ class Operator
6570 protected:
6671 int ik = 0 ;
6772
68- mutable bool recursive = false ;
73+ mutable bool in_place = false ;
6974
7075 // calculation type, only different type can be in main chain table
7176 int cal_type = 0 ;
@@ -74,30 +79,39 @@ class Operator
7479 // if this Operator is first node in chain table, hpsi would not be empty
7580 mutable psi::Psi<T>* hpsi = nullptr ;
7681
82+ /* This function would analyze hpsi_info and choose how to arrange hpsi storage
83+ In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer;
84+ if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare.
85+ two cases would occurred:
86+ 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory
87+ 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case
88+ */
7789 T* get_hpsi (const hpsi_info& info)const
7890 {
7991 const int nbands_range = (std::get<1 >(info).range_2 - std::get<1 >(info).range_1 + 1 );
80- // recursive call of hPsi, hpsi inputs as new psi,
92+ // in_place call of hPsi, hpsi inputs as new psi,
8193 // create a new hpsi and delete old hpsi later
82- if (this ->hpsi != std::get<0 >(info) )
94+ T* hpsi_pointer = std::get<2 >(info);
95+ const T* psi_pointer = std::get<0 >(info)->get_pointer ();
96+ if (!hpsi_pointer)
8397 {
84- this ->recursive = false ;
85- if (this ->hpsi != nullptr )
86- {
87- delete this ->hpsi ;
88- }
98+ ModuleBase::WARNING_QUIT (" Operator::hPsi" , " hpsi_pointer can not be nullptr" );
99+ }
100+ else if (hpsi_pointer == psi_pointer)
101+ {
102+ this ->in_place = true ;
103+ this ->hpsi = new psi::Psi<T>(std::get<0 >(info)[0 ], 1 , nbands_range);
89104 }
90105 else
91106 {
92- this ->recursive = true ;
107+ this ->in_place = false ;
108+ this ->hpsi = new psi::Psi<T>(hpsi_pointer, std::get<0 >(info)[0 ], 1 , nbands_range);
93109 }
94- // create a new hpsi
95- this ->hpsi = new psi::Psi<T>(std::get<0 >(info)[0 ], 1 , nbands_range);
96110
97- T* pointer_hpsi = this ->hpsi ->get_pointer ();
111+ hpsi_pointer = this ->hpsi ->get_pointer ();
98112 size_t total_hpsi_size = nbands_range * this ->hpsi ->get_nbasis ();
99- ModuleBase::GlobalFunc::ZEROS (pointer_hpsi , total_hpsi_size);
100- return pointer_hpsi ;
113+ ModuleBase::GlobalFunc::ZEROS (hpsi_pointer , total_hpsi_size);
114+ return hpsi_pointer ;
101115 }
102116};
103117
0 commit comments