44
55using namespace hamilt ;
66
7-
8- template <typename T, typename Device>
9- Operator<T, Device>::Operator(){}
10-
11- template <typename T, typename Device>
12- Operator<T, Device>::~Operator ()
7+ template <typename T, typename Device>
8+ Operator<T, Device>::Operator()
139{
14- if (this ->hpsi != nullptr ) { delete this ->hpsi ;
1510}
11+
12+ template <typename T, typename Device>
13+ Operator<T, Device>::~Operator ()
14+ {
15+ if (this ->hpsi != nullptr )
16+ {
17+ delete this ->hpsi ;
18+ }
1619 Operator* last = this ->next_op ;
1720 Operator* last_sub = this ->next_sub_op ;
18- while (last != nullptr || last_sub != nullptr )
21+ while (last != nullptr || last_sub != nullptr )
1922 {
20- if (last_sub != nullptr )
21- {// delete sub_chain first
23+ if (last_sub != nullptr )
24+ { // delete sub_chain first
2225 Operator* node_delete = last_sub;
2326 last_sub = last_sub->next_sub_op ;
2427 node_delete->next_sub_op = nullptr ;
2528 delete node_delete;
2629 }
2730 else
28- {// delete main chain if sub_chain is deleted
31+ { // delete main chain if sub_chain is deleted
2932 Operator* node_delete = last;
3033 last_sub = last->next_sub_op ;
3134 node_delete->next_sub_op = nullptr ;
@@ -36,7 +39,7 @@ Operator<T, Device>::~Operator()
3639 }
3740}
3841
39- template <typename T, typename Device>
42+ template <typename T, typename Device>
4043typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
4144{
4245 using syncmem_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
@@ -46,37 +49,51 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
4649
4750 T* tmhpsi = this ->get_hpsi (input);
4851 const T* tmpsi_in = std::get<0 >(psi_info);
49- // if range in hpsi_info is illegal, the first return of to_range() would be nullptr
52+ // if range in hpsi_info is illegal, the first return of to_range() would be nullptr
5053 if (tmpsi_in == nullptr )
5154 {
5255 ModuleBase::WARNING_QUIT (" Operator" , " please choose correct range of psi for hPsi()!" );
5356 }
54- // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
57+ // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
5558 T* hpsi_pointer = std::get<2 >(input);
5659 if (this ->in_place )
5760 {
5861 // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
5962 syncmem_op ()(this ->ctx , this ->ctx , hpsi_pointer, this ->hpsi ->get_pointer (), this ->hpsi ->size ());
6063 delete this ->hpsi ;
61- this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1 , nbands / psi_input->npol );
64+ this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
65+ 1 ,
66+ nbands / psi_input->npol ,
67+ psi_input->get_nbasis (),
68+ psi_input->get_nbasis (),
69+ true );
6270 }
6371
6472 auto call_act = [&, this ](const Operator* op, const bool & is_first_node) -> void {
65-
6673 // a "psi" with the bands of needed range
67- psi::Psi<T, Device> psi_wrapper (const_cast <T*>(tmpsi_in), 1 , nbands, psi_input->get_nbasis (), true );
68-
69-
74+ psi::Psi<T, Device> psi_wrapper (const_cast <T*>(tmpsi_in),
75+ 1 ,
76+ nbands,
77+ psi_input->get_nbasis (),
78+ psi_input->get_nbasis (),
79+ true );
80+
7081 switch (op->get_act_type ())
7182 {
7283 case 2 :
7384 op->act (psi_wrapper, *this ->hpsi , nbands);
7485 break ;
7586 default :
76- op->act (nbands, psi_input->get_nbasis (), psi_input->npol , tmpsi_in, this ->hpsi ->get_pointer (), psi_input->get_ngk (op->ik ), is_first_node);
87+ op->act (nbands,
88+ psi_input->get_nbasis (),
89+ psi_input->npol ,
90+ tmpsi_in,
91+ this ->hpsi ->get_pointer (),
92+ psi_input->get_current_nbas (),
93+ is_first_node);
7794 break ;
7895 }
79- };
96+ };
8097
8198 ModuleBase::timer::tick (" Operator" , " hPsi" );
8299 call_act (this , true ); // first node
@@ -91,39 +108,43 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
91108 return hpsi_info (this ->hpsi , psi::Range (1 , 0 , 0 , nbands / psi_input->npol ), hpsi_pointer);
92109}
93110
94-
95- template <typename T, typename Device>
96- void Operator<T, Device>::init(const int ik_in)
111+ template <typename T, typename Device>
112+ void Operator<T, Device>::init(const int ik_in)
97113{
98114 this ->ik = ik_in;
99- if (this ->next_op != nullptr ) {
115+ if (this ->next_op != nullptr )
116+ {
100117 this ->next_op ->init (ik_in);
101118 }
102119}
103120
104- template <typename T, typename Device>
105- void Operator<T, Device>::add(Operator* next)
121+ template <typename T, typename Device>
122+ void Operator<T, Device>::add(Operator* next)
106123{
107- if (next==nullptr ) { return ;
108- }
124+ if (next == nullptr )
125+ {
126+ return ;
127+ }
109128 next->is_first_node = false ;
110- if (next->next_op != nullptr ) { this ->add (next->next_op );
111- }
129+ if (next->next_op != nullptr )
130+ {
131+ this ->add (next->next_op );
132+ }
112133 Operator* last = this ;
113- // loop to end of the chain
114- while (last->next_op != nullptr )
134+ // loop to end of the chain
135+ while (last->next_op != nullptr )
115136 {
116- if (next->cal_type == last->cal_type )
137+ if (next->cal_type == last->cal_type )
117138 {
118139 break ;
119140 }
120141 last = last->next_op ;
121142 }
122- if (next->cal_type == last->cal_type )
143+ if (next->cal_type == last->cal_type )
123144 {
124- // insert next to sub chain of current node
145+ // insert next to sub chain of current node
125146 Operator* sub_last = last;
126- while (sub_last->next_sub_op != nullptr )
147+ while (sub_last->next_sub_op != nullptr )
127148 {
128149 sub_last = sub_last->next_sub_op ;
129150 }
@@ -136,34 +157,45 @@ void Operator<T, Device>::add(Operator* next)
136157 }
137158}
138159
139- template <typename T, typename Device>
160+ template <typename T, typename Device>
140161T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
141162{
142163 const int nbands_range = (std::get<1 >(info).range_2 - std::get<1 >(info).range_1 + 1 );
143- // in_place call of hPsi, hpsi inputs as new psi,
144- // create a new hpsi and delete old hpsi later
164+ // in_place call of hPsi, hpsi inputs as new psi,
165+ // create a new hpsi and delete old hpsi later
145166 T* hpsi_pointer = std::get<2 >(info);
146167 const T* psi_pointer = std::get<0 >(info)->get_pointer ();
147- if (this ->hpsi != nullptr )
168+ if (this ->hpsi != nullptr )
148169 {
149170 delete this ->hpsi ;
150171 this ->hpsi = nullptr ;
151172 }
152- if (!hpsi_pointer)
173+ if (!hpsi_pointer)
153174 {
154175 ModuleBase::WARNING_QUIT (" Operator::hPsi" , " hpsi_pointer can not be nullptr" );
155176 }
156- else if (hpsi_pointer == psi_pointer)
177+ else if (hpsi_pointer == psi_pointer)
157178 {
158179 this ->in_place = true ;
159- this ->hpsi = new psi::Psi<T, Device>(std::get<0 >(info)[0 ], 1 , nbands_range);
180+ // this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
181+ this ->hpsi = new psi::Psi<T, Device>(1 ,
182+ nbands_range,
183+ std::get<0 >(info)->get_nbasis (),
184+ std::get<0 >(info)->get_nbasis (),
185+ true );
160186 }
161187 else
162188 {
163189 this ->in_place = false ;
164- this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0 >(info)[0 ], 1 , nbands_range);
190+
191+ this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
192+ 1 ,
193+ nbands_range,
194+ std::get<0 >(info)->get_nbasis (),
195+ std::get<0 >(info)->get_nbasis (),
196+ true );
165197 }
166-
198+
167199 hpsi_pointer = this ->hpsi ->get_pointer ();
168200 size_t total_hpsi_size = nbands_range * this ->hpsi ->get_nbasis ();
169201 // ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
@@ -172,7 +204,8 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
172204 return hpsi_pointer;
173205}
174206
175- namespace hamilt {
207+ namespace hamilt
208+ {
176209template class Operator <float , base_device::DEVICE_CPU>;
177210template class Operator <std::complex <float >, base_device::DEVICE_CPU>;
178211template class Operator <double , base_device::DEVICE_CPU>;
@@ -183,4 +216,4 @@ template class Operator<std::complex<float>, base_device::DEVICE_GPU>;
183216template class Operator <double , base_device::DEVICE_GPU>;
184217template class Operator <std::complex <double >, base_device::DEVICE_GPU>;
185218#endif
186- }
219+ } // namespace hamilt
0 commit comments