@@ -532,17 +532,17 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
532532 // to G space. maybe need fftw with OpenMP
533533 rho_basis->real2recip (aux, aux);
534534
535- // =============== GPU/CPU异构优化:使用sincos op替换原循环 ===============
535+ // sincos op for G space
536536
537537
538- // 准备原子相关数据:按照全局原子索引iat顺序
538+ // data preparation
539539 std::vector<FPTYPE> tau_flat (this ->nat * 3 );
540540 std::vector<FPTYPE> gcar_flat (rho_basis->npw * 3 );
541541
542- // 按照原始代码逻辑:遍历全局原子索引iat,通过查表获取(it,ia)
542+
543543 for (int iat = 0 ; iat < this ->nat ; iat++) {
544- int it = ucell.iat2it [iat]; // 查表获取原子类型
545- int ia = ucell.iat2ia [iat]; // 查表获取该类型内的原子索引
544+ int it = ucell.iat2it [iat];
545+ int ia = ucell.iat2ia [iat];
546546
547547 tau_flat[iat * 3 + 0 ] = static_cast <FPTYPE>(ucell.atoms [it].tau [ia][0 ]);
548548 tau_flat[iat * 3 + 1 ] = static_cast <FPTYPE>(ucell.atoms [it].tau [ia][1 ]);
@@ -555,7 +555,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
555555 gcar_flat[ig * 3 + 2 ] = static_cast <FPTYPE>(rho_basis->gcar [ig][2 ]);
556556 }
557557
558- // 重新计算vloc_factors考虑所有原子类型
558+ // calculate vloc_factors for all atom types
559559 std::vector<FPTYPE> vloc_per_type_host (ucell.ntype * rho_basis->npw );
560560 for (int iat = 0 ; iat < this ->nat ; iat++) {
561561 int it = ucell.iat2it [iat];
@@ -564,13 +564,11 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
564564 }
565565 }
566566
567- // 转换aux到FPTYPE类型
568567 std::vector<std::complex <FPTYPE>> aux_fptype (rho_basis->npw );
569568 for (int ig = 0 ; ig < rho_basis->npw ; ig++) {
570569 aux_fptype[ig] = static_cast <std::complex <FPTYPE>>(aux[ig]);
571570 }
572571
573- // 设备端内存和指针设置(根据设备类型分支)
574572 FPTYPE* d_gcar = gcar_flat.data ();
575573 FPTYPE* d_tau = tau_flat.data ();
576574 FPTYPE* d_vloc_per_type = vloc_per_type_host.data ();
@@ -591,26 +589,22 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
591589 resmem_complex_op ()(this ->ctx , d_aux, rho_basis->npw );
592590 resmem_var_op ()(this ->ctx , d_force, this ->nat * 3 );
593591
594- // 数据传输到设备
595592 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_gcar, gcar_flat.data (), rho_basis->npw * 3 );
596593 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_tau, tau_flat.data (), this ->nat * 3 );
597594 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_vloc_per_type, vloc_per_type_host.data (), ucell.ntype * rho_basis->npw );
598595 syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , d_aux, aux_fptype.data (), rho_basis->npw );
599596
600- // 初始化force为0
601597 base_device::memory::set_memory_op<FPTYPE, Device>()(this ->ctx , d_force, 0.0 , this ->nat * 3 );
602598 }
603599 else
604600 {
605601 d_force = force_host.data ();
606- // 对于CPU,直接初始化为0
607602 std::fill (force_host.begin (), force_host.end (), static_cast <FPTYPE>(0.0 ));
608603 }
609604
610- // 计算缩放因子
611605 const FPTYPE scale_factor = static_cast <FPTYPE>(ucell.tpiba * ucell.omega );
612606
613- // 调用op进行sincos计算
607+ // call op for sincos calculation
614608 hamilt::cal_force_loc_sincos_op<FPTYPE, Device>()(
615609 this ->ctx ,
616610 this ->nat ,
@@ -624,21 +618,17 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
624618 d_force
625619 );
626620
627- // 根据设备类型处理结果
628621 if (this ->device == base_device::GpuDevice)
629622 {
630- // 结果传回CPU
631623 syncmem_var_d2h_op ()(this ->cpu_ctx , this ->ctx , force_host.data (), d_force, this ->nat * 3 );
632624
633- // 清理设备内存
634625 delmem_var_op ()(this ->ctx , d_gcar);
635626 delmem_var_op ()(this ->ctx , d_tau);
636627 delmem_var_op ()(this ->ctx , d_vloc_per_type);
637628 delmem_complex_op ()(this ->ctx , d_aux);
638629 delmem_var_op ()(this ->ctx , d_force);
639630 }
640631
641- // 将结果写入forcelc矩阵
642632 for (int iat = 0 ; iat < this ->nat ; iat++) {
643633 forcelc (iat, 0 ) = static_cast <double >(force_host[iat * 3 + 0 ]);
644634 forcelc (iat, 1 ) = static_cast <double >(force_host[iat * 3 + 1 ]);
@@ -755,17 +745,15 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
755745 aux[rho_basis->ig_gge0 ] = std::complex <double >(0.0 , 0.0 );
756746 }
757747
758- // =============== 第一步:G空间sincos计算(在OpenMP区域外)===============
748+ // sincos op for cal_force_ew
759749
760- // 预计算每个原子的it_fact和相关数据:按照全局原子索引iat顺序
761750 std::vector<FPTYPE> it_facts_host (this ->nat );
762751 std::vector<FPTYPE> tau_flat (this ->nat * 3 );
763- std::vector<int > iat2it_host (this ->nat );
764752
765- // 按照原始代码逻辑:遍历全局原子索引iat,通过查表获取(it,ia)
753+ // iterate over the global atom index iat; look up (it, ia) via iat2it / iat2ia
766754 for (int iat = 0 ; iat < this ->nat ; iat++) {
767- int it = ucell.iat2it [iat]; // 查表获取原子类型
768- int ia = ucell.iat2ia [iat]; // 查表获取该类型内的原子索引
755+ int it = ucell.iat2it [iat];
756+ int ia = ucell.iat2ia [iat];
769757
770758 double zv;
771759 if (PARAM.inp .use_paw )
@@ -785,27 +773,22 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
785773 tau_flat[iat * 3 + 0 ] = static_cast <FPTYPE>(ucell.atoms [it].tau [ia][0 ]);
786774 tau_flat[iat * 3 + 1 ] = static_cast <FPTYPE>(ucell.atoms [it].tau [ia][1 ]);
787775 tau_flat[iat * 3 + 2 ] = static_cast <FPTYPE>(ucell.atoms [it].tau [ia][2 ]);
788- iat2it_host[iat] = it;
789776 }
790777
791- // 准备设备端数据
792778 std::vector<FPTYPE> gcar_flat (rho_basis->npw * 3 );
793779 for (int ig = 0 ; ig < rho_basis->npw ; ig++) {
794780 gcar_flat[ig * 3 + 0 ] = static_cast <FPTYPE>(rho_basis->gcar [ig][0 ]);
795781 gcar_flat[ig * 3 + 1 ] = static_cast <FPTYPE>(rho_basis->gcar [ig][1 ]);
796782 gcar_flat[ig * 3 + 2 ] = static_cast <FPTYPE>(rho_basis->gcar [ig][2 ]);
797783 }
798784
799- // 转换aux到FPTYPE类型
800785 std::vector<std::complex <FPTYPE>> aux_fptype (rho_basis->npw );
801786 for (int ig = 0 ; ig < rho_basis->npw ; ig++) {
802787 aux_fptype[ig] = static_cast <std::complex <FPTYPE>>(aux[ig]);
803788 }
804789
805- // 设备端内存和指针设置(根据设备类型分支)
806790 FPTYPE* d_gcar = gcar_flat.data ();
807791 FPTYPE* d_tau = tau_flat.data ();
808- int * d_iat2it = iat2it_host.data ();
809792 FPTYPE* d_it_facts = it_facts_host.data ();
810793 std::complex <FPTYPE>* d_aux = aux_fptype.data ();
811794 FPTYPE* d_force_g = nullptr ;
@@ -815,72 +798,66 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
815798 {
816799 d_gcar = nullptr ;
817800 d_tau = nullptr ;
818- d_iat2it = nullptr ;
819801 d_it_facts = nullptr ;
820802 d_aux = nullptr ;
821803
822804 resmem_var_op ()(this ->ctx , d_gcar, rho_basis->npw * 3 );
823805 resmem_var_op ()(this ->ctx , d_tau, this ->nat * 3 );
824- resmem_int_op ()(this ->ctx , d_iat2it, this ->nat );
825806 resmem_var_op ()(this ->ctx , d_it_facts, this ->nat );
826807 resmem_complex_op ()(this ->ctx , d_aux, rho_basis->npw );
827808 resmem_var_op ()(this ->ctx , d_force_g, this ->nat * 3 );
828809
829- // 数据传输
810+
830811 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_gcar, gcar_flat.data (), rho_basis->npw * 3 );
831812 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_tau, tau_flat.data (), this ->nat * 3 );
832- syncmem_int_h2d_op ()(this ->ctx , this ->cpu_ctx , d_iat2it, iat2it_host.data (), this ->nat );
833813 syncmem_var_h2d_op ()(this ->ctx , this ->cpu_ctx , d_it_facts, it_facts_host.data (), this ->nat );
834814 syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , d_aux, aux_fptype.data (), rho_basis->npw );
835815
836- // 初始化force为0
816+
837817 base_device::memory::set_memory_op<FPTYPE, Device>()(this ->ctx , d_force_g, 0.0 , this ->nat * 3 );
838818 }
839819 else
840820 {
841821 d_force_g = force_g_host.data ();
842- // 对于CPU,直接初始化为0
843822 std::fill (force_g_host.begin (), force_g_host.end (), static_cast <FPTYPE>(0.0 ));
844823 }
845824
846- // 调用op处理G空间sincos计算(在OpenMP区域外,无冲突)
825+ // call op for sincos calculation
847826 hamilt::cal_force_ew_sincos_op<FPTYPE, Device>()(
848827 this ->ctx ,
849828 this ->nat ,
850829 rho_basis->npw ,
851- rho_basis->ig_gge0 , // G=0项索引,op内部会自动跳过
830+ rho_basis->ig_gge0 ,
852831 d_gcar,
853832 d_tau,
854- d_iat2it,
855833 d_it_facts,
856834 d_aux,
857835 d_force_g
858836 );
859837
860- // 根据设备类型处理结果
838+
861839 if (this ->device == base_device::GpuDevice)
862840 {
863- // 将G空间结果传回CPU
841+
864842 syncmem_var_d2h_op ()(this ->cpu_ctx , this ->ctx , force_g_host.data (), d_force_g, this ->nat * 3 );
865843
866- // 清理设备内存
844+
867845 delmem_var_op ()(this ->ctx , d_gcar);
868846 delmem_var_op ()(this ->ctx , d_tau);
869- delmem_int_op ()(this ->ctx , d_iat2it);
870847 delmem_var_op ()(this ->ctx , d_it_facts);
871848 delmem_complex_op ()(this ->ctx , d_aux);
872849 delmem_var_op ()(this ->ctx , d_force_g);
873850 }
874851
875- // 累加到forceion
852+
876853 for (int iat = 0 ; iat < this ->nat ; iat++) {
877854 forceion (iat, 0 ) += static_cast <double >(force_g_host[iat * 3 + 0 ]);
878855 forceion (iat, 1 ) += static_cast <double >(force_g_host[iat * 3 + 1 ]);
879856 forceion (iat, 2 ) += static_cast <double >(force_g_host[iat * 3 + 2 ]);
880857 }
881858
882- // =============== 第二步:实空间计算(保留原OpenMP结构)===============
883-
859+
860+ // calculate real space force
884861#ifdef _OPENMP
885862#pragma omp parallel
886863 {
@@ -904,7 +881,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
904881 iat_end = iat_beg + iat_end;
905882 ucell.iat2iait (iat_beg, &ia_beg, &it_beg);
906883
907- // 只保留实空间相互作用计算(means that the processor contains G=0 term)
884+
908885 if (rho_basis->ig_gge0 >= 0 )
909886 {
910887 double rmax = 5.0 / (sqrt (alpha) * ucell.lat0 );
0 commit comments