1313import torch_optimizer
1414import IO .cal_weight
1515import util
16+ import IO .change_info
17+ import pprint
1618
1719def main ():
1820 seed = int (1000 * time .time ())% (2 ** 32 )
@@ -24,51 +26,59 @@ def main():
2426
2527 weight = IO .cal_weight .cal_weight (weight_info , V_info ["same_band" ], file_list ["origin" ])
2628
27- QI ,SI ,VI_origin ,info = IO .read_QSV .read_file (info_true ,file_list ["origin" ],V_info )
28- print (info , flush = True )
29+ QI ,SI ,VI_origin ,info_kst = IO .read_QSV .read_file (info_true ,file_list ["origin" ],V_info )
2930 if "linear" in file_list .keys ():
3031 QI_linear , SI_linear , VI_linear , info_linear = list (zip (* ( IO .read_QSV .read_file (info_true ,file ,V_info ) for file in file_list ["linear" ] )))
3132
33+ info_stru , info_element , info_opt = IO .change_info .change_info (info_kst ,weight )
34+
35+ print (pprint .pformat (info_stru ), end = "\n " * 2 , flush = True )
36+ print (pprint .pformat (info_element ,width = 40 ), end = "\n " * 2 , flush = True )
37+ print (pprint .pformat (info_opt ,width = 40 ), end = "\n " * 2 , flush = True )
38+
3239 if C_init_info ["init_from_file" ]:
33- C , C_read_index = IO .func_C .read_C_init ( C_init_info ["C_init_file" ], info )
40+ C , C_read_index = IO .func_C .read_C_init ( C_init_info ["C_init_file" ], info_element )
3441 else :
35- C = IO .func_C .random_C_init (info )
36- E = orbital .set_E (info ,info .Rcut )
37- orbital .normalize ( orbital .generate_orbital (info ,C ,E ,info .Rcut ,info .dr ), info .dr ,C ,flag_norm_C = True )
42+ C = IO .func_C .random_C_init (info_element )
43+ E = orbital .set_E (info_element )
44+ orbital .normalize (
45+ orbital .generate_orbital (info_element ,C ,E ),
46+ {it :info_element [it ].dr for it in info_element },
47+ C , flag_norm_C = True )
3848
3949 opt_orb = opt_orbital .Opt_Orbital ()
4050
41- #opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info .lr, eps=1e-8)
42- #opt = torch.optim.Adam( sum(C.values(),[]), lr=info .lr, eps=1e-20, weight_decay=info .weight_decay)
43- #opt = radam.RAdam( sum(C.values(),[]), lr=info .lr, eps=1e-20 )
44- opt = torch_optimizer .SWATS ( sum (C .values (),[]), lr = info .lr , eps = 1e-20 )
51+ #opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info_opt .lr, eps=1e-8)
52+ #opt = torch.optim.Adam( sum(C.values(),[]), lr=info_opt .lr, eps=1e-20, weight_decay=info_opt .weight_decay)
53+ #opt = radam.RAdam( sum(C.values(),[]), lr=info_opt .lr, eps=1e-20 )
54+ opt = torch_optimizer .SWATS ( sum (C .values (),[]), lr = info_opt .lr , eps = 1e-20 )
4555
4656
4757 with open ("Spillage.dat" ,"w" ) as S_file :
4858
4959 print ( "\n See \" Spillage.dat\" for detail status: " , flush = True )
50- if info .cal_T :
60+ if info_opt .cal_T :
5161 print ( '%5s' % "istep" , "%20s" % "Spillage" , "%20s" % "T.item()" , "%20s" % "Loss" , flush = True )
5262 else :
5363 print ( '%5s' % "istep" , "%20s" % "Spillage" , flush = True )
5464
5565 loss_old = np .inf
56- for istep in range (200 ):
66+ for istep in range (3 ):
5767
5868 Spillage = 0
59- for ist in range (info . Nst ):
69+ for ist in range (len ( info_stru ) ):
6070
61- Q = opt_orb .change_index_Q (opt_orb .cal_Q (QI [ist ],C ,info , ist ), info , ist )
62- S = opt_orb .change_index_S (opt_orb .cal_S (SI [ist ],C ,info , ist ), info , ist )
71+ Q = opt_orb .change_index_Q (opt_orb .cal_Q (QI [ist ],C ,info_stru [ ist ], info_element ), info_stru [ ist ] )
72+ S = opt_orb .change_index_S (opt_orb .cal_S (SI [ist ],C ,info_stru [ ist ], info_element ), info_stru [ ist ], info_element )
6373 coef = opt_orb .cal_coef (Q ,S )
6474 V = opt_orb .cal_V (coef ,Q )
6575 V_origin = opt_orb .cal_V_origin (V ,V_info )
6676
6777 if "linear" in file_list .keys ():
6878 V_linear = [None ] * len (file_list ["linear" ])
6979 for i in range (len (file_list ["linear" ])):
70- Q_linear = opt_orb .change_index_Q (opt_orb .cal_Q (QI_linear [i ][ist ],C ,info , ist ), info , ist )
71- S_linear = opt_orb .change_index_S (opt_orb .cal_S (SI_linear [i ][ist ],C ,info , ist ), info , ist )
80+ Q_linear = opt_orb .change_index_Q (opt_orb .cal_Q (QI_linear [i ][ist ],C ,info_stru [ ist ], info_element ), info_stru [ ist ] )
81+ S_linear = opt_orb .change_index_S (opt_orb .cal_S (SI_linear [i ][ist ],C ,info_stru [ ist ], info_element ), info_stru [ ist ], info_element )
7282 V_linear [i ] = opt_orb .cal_V_linear (coef ,Q_linear ,S_linear ,V ,V_info )
7383
7484 def cal_Spillage (V_delta ):
@@ -83,14 +93,14 @@ def cal_delta(VI, V):
8393 for i in range (len (file_list ["linear" ])):
8494 Spillage += cal_Spillage (cal_delta (VI_linear [i ],V_linear [i ]))
8595
86- if info .cal_T :
96+ if info_opt .cal_T :
8797 T = opt_orb .cal_T (C ,E )
8898 if not "TSrate" in vars (): TSrate = torch .abs (0.002 * Spillage / T ).data [0 ]
8999 Loss = Spillage + TSrate * T
90100 else :
91101 Loss = Spillage
92102
93- if info .cal_T :
103+ if info_opt .cal_T :
94104 print_content = [istep , Spillage .item (), T .item (), Loss .item ()]
95105 else :
96106 print_content = [istep , Spillage .item ()]
@@ -100,7 +110,7 @@ def cal_delta(VI, V):
100110
101111 if Loss .item () < loss_old :
102112 loss_old = Loss .item ()
103- C_old = IO .func_C .copy_C (C ,info )
113+ C_old = IO .func_C .copy_C (C ,info_element )
104114 flag_finish = 0
105115 else :
106116 flag_finish += 1
@@ -113,16 +123,27 @@ def cal_delta(VI, V):
113123 for it ,il ,iu in C_read_index :
114124 C [it ][il ].grad [:,iu ] = 0
115125 opt .step ()
116- # orbital.normalize( orbital.generate_orbital(info,C,E,info.Rcut,info.dr), info.dr,C,flag_norm_C=True)
117-
118- orb = orbital .generate_orbital (info ,C_old ,E ,info .Rcut ,info .dr )
119- if info .cal_smooth :
120- orbital .smooth_orbital (orb ,info .Rcut ,info .dr ,0.1 )
121- orbital .orth (orb ,info .dr )
122- IO .print_orbital .print_orbital (orb ,info )
123- IO .print_orbital .plot_orbital (orb ,info .Rcut ,info .dr )
124-
125- IO .func_C .write_C ("ORBITAL_RESULTS.txt" ,info ,C_old ,Spillage )
126+ # orbital.normalize(
127+ # orbital.generate_orbital(info_element,C,E),
128+ # {it:info_element[it].dr for it in info_element},
129+ # C, flag_norm_C=True)
130+
131+ orb = orbital .generate_orbital (info_element ,C_old ,E )
132+ if info_opt .cal_smooth :
133+ orbital .smooth_orbital (
134+ orb ,
135+ {it :info_element [it ].Rcut for it in info_element }, {it :info_element [it ].dr for it in info_element },
136+ 0.1 )
137+ orbital .orth (
138+ orb ,
139+ {it :info_element [it ].dr for it in info_element })
140+ IO .print_orbital .print_orbital (orb ,info_element )
141+ IO .print_orbital .plot_orbital (
142+ orb ,
143+ {it :info_element [it ].Rcut for it in info_element },
144+ {it :info_element [it ].dr for it in info_element })
145+
146+ IO .func_C .write_C ("ORBITAL_RESULTS.txt" ,C_old ,Spillage )
126147
127148 print ("Time (PyTorch): %s\n " % (time .time ()- time_start ), flush = True )
128149
0 commit comments