55#include " source_base/timer.h"
66
77#include < complex>
8+ #include < algorithm>
9+ #include < type_traits>
810#include < stdexcept>
911
1012#include " source_base/tool_quit.h"
1517
1618using namespace ROPTLITE ;
1719
20+ namespace
21+ {
22+ template <typename ValueType>
23+ inline ValueType convert_complex_entry (const std::complex <double >& entry, std::true_type)
24+ {
25+ return static_cast <ValueType>(entry.real ());
26+ }
27+
28+ template <typename ValueType>
29+ inline ValueType convert_complex_entry (const std::complex <double >& entry, std::false_type)
30+ {
31+ return static_cast <ValueType>(entry);
32+ }
33+ } // namespace
34+
1835namespace ModuleESolver
1936{
2037
@@ -24,6 +41,8 @@ RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
2441 // this->classname = "RDMFT_LCAO";
2542 // this->basisname = "LCAO";
2643
44+ this ->ucell_ref_ = &ucell;
45+
2746#ifdef __EXX
2847 // 1. currently this initialization must be put in constructor rather than `before_all_runners()`
2948 // because the latter is not reused by ESolver_LCAO_TDDFT,
@@ -138,9 +157,23 @@ RDMFT_LCAO<TK, TR>::~RDMFT_LCAO()
138157template <typename TK, typename TR>
139158realdp RDMFT_LCAO<TK, TR>::f(const Variable& x) const
140159{
160+ // convert Variable x to occupation numbers and wavefunctions
161+ ModuleBase::matrix occ_num;
162+ psi::Psi<TK> wfc;
163+
164+ VariableToOccNumWfc (x, occ_num, wfc);
165+
166+ if (this ->ucell_ref_ == nullptr )
167+ {
168+ ModuleBase::WARNING (" RDMFT_LCAO" , " UnitCell reference is unavailable before calling update_elec" );
169+ throw std::runtime_error (" RDMFT_LCAO::f requires a valid UnitCell reference" );
170+ }
171+
172+ this ->rdmft_solver .update_elec (*this ->ucell_ref_ , occ_num, wfc);
173+
141174 this ->rdmft_solver .cal_Energy (1 ); // 1 means not to calculate forces here
142175
143- std::cout << " Called f() in RDMFT_LCAO, Etotal = " << this ->rdmft_solver .Etotal << std::endl;
176+ // std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
144177 return this ->rdmft_solver .Etotal ;
145178}
146179
@@ -173,6 +206,105 @@ Vector& RDMFT_LCAO<TK, TR>::EucHessianEta(const Variable& x, const Vector& etax,
173206 return *result;
174207}
175208
209+ template <typename TK, typename TR>
210+ void RDMFT_LCAO<TK, TR>::VariableToOccNumWfc(const Variable& x,
211+ ModuleBase::matrix& occ_num,
212+ psi::Psi<TK>& wfc) const
213+ {
214+ const int required_components = this ->nk_total ;
215+ const int expected_elements = required_components * 2 ;
216+ if (x.Getnumofelements () != expected_elements)
217+ {
218+ throw std::runtime_error (" RDMFT_LCAO::VariableToOccNumWfc received a variable with unexpected component count" );
219+ }
220+
221+ if (occ_num.nr != this ->nk_total || occ_num.nc != this ->nbands_total )
222+ {
223+ occ_num.create (this ->nk_total , this ->nbands_total , false );
224+ }
225+
226+ const int local_basis_count = this ->pv .get_row_size ();
227+ const int local_band_count = this ->pv .ncol_bands ;
228+ const int available_band_slots = this ->pv .get_col_size ();
229+ const int band_limit = std::min (local_band_count, available_band_slots);
230+
231+ if (wfc.get_nk () != this ->nk_total || wfc.get_nbands () != local_band_count || wfc.get_nbasis () != local_basis_count)
232+ {
233+ wfc.resize (this ->nk_total , local_band_count, local_basis_count);
234+ }
235+
236+ // Occupation numbers occupy the second half of the product element.
237+ const int occ_offset = required_components;
238+ for (int ik = 0 ; ik < this ->nk_total ; ++ik)
239+ {
240+ const Element& occ_elem = x.GetElement (occ_offset + ik);
241+ if (occ_elem.Getrow () != this ->nbands_total || occ_elem.Getcol () != 1 )
242+ {
243+ throw std::runtime_error (" RDMFT_LCAO::VariableToOccNumWfc encountered an occupation component with unexpected shape" );
244+ }
245+
246+ const double * occ_values = occ_elem.ObtainReadData ();
247+ for (int ib = 0 ; ib < this ->nbands_total ; ++ib)
248+ {
249+ occ_num (ik, ib) = occ_values[ib];
250+ }
251+ }
252+
253+ // Wavefunction blocks sit in the first half of the product element.
254+ for (int ik = 0 ; ik < this ->nk_total ; ++ik)
255+ {
256+ const Element& wf_elem = x.GetElement (ik);
257+ const int global_rows = wf_elem.Getrow ();
258+ const int global_cols = wf_elem.Getcol ();
259+
260+ if (global_rows != this ->nbasis_total || global_cols != this ->nbands_total )
261+ {
262+ throw std::runtime_error (" RDMFT_LCAO::VariableToOccNumWfc encountered a wavefunction block with unexpected shape" );
263+ }
264+
265+ const double * raw_data = wf_elem.ObtainReadData ();
266+ const bool is_complex = wf_elem.Getiscomplex ();
267+
268+ wfc.fix_k (ik);
269+
270+ for (int ib_local = 0 ; ib_local < band_limit; ++ib_local)
271+ {
272+ const int band_global = this ->pv .local2global_col (ib_local);
273+ if (band_global < 0 || band_global >= global_cols)
274+ {
275+ continue ;
276+ }
277+
278+ wfc.fix_b (ib_local);
279+
280+ for (int ir_local = 0 ; ir_local < local_basis_count; ++ir_local)
281+ {
282+ const int basis_global = this ->pv .local2global_row (ir_local);
283+ if (basis_global < 0 || basis_global >= global_rows)
284+ {
285+ continue ;
286+ }
287+
288+ const int linear_index = basis_global + global_rows * band_global;
289+
290+ TK value;
291+ if (is_complex)
292+ {
293+ const auto * complex_data = reinterpret_cast <const std::complex <double >*>(raw_data);
294+ const std::complex <double > entry = complex_data[linear_index];
295+ value = convert_complex_entry<TK>(entry, typename std::is_same<TK, double >::type ());
296+ }
297+ else
298+ {
299+ value = static_cast <TK>(raw_data[linear_index]);
300+ }
301+
302+ wfc (ir_local) = value;
303+ }
304+ }
305+ }
306+ }
307+
176308
177309
178310
0 commit comments