Skip to content

Commit 97faf88

Browse files
committed
1. In opt_orb_pytorch_dpsi, split info to info_stru, info_element, info_opt
1 parent efd30d4 commit 97faf88

File tree

8 files changed

+149
-92
lines changed

8 files changed

+149
-92
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import addict
2+
3+
def change_info(info_old, weight_old):
4+
info_stru = [None] * info_old.Nst
5+
for ist in range(len(info_stru)):
6+
info_stru[ist] = addict.Dict()
7+
for ist,Na in enumerate(info_old.Na):
8+
info_stru[ist].Na = Na
9+
for ist,weight in enumerate(weight_old):
10+
info_stru[ist].weight = weight
11+
info_stru[ist].Nb = weight.shape[0]
12+
13+
info_element = addict.Dict()
14+
for it,Nu in info_old.Nu.items():
15+
info_element[it].Nu = Nu
16+
info_element[it].Nl = len(Nu)
17+
for it,Rcut in info_old.Rcut.items():
18+
info_element[it].Rcut = Rcut
19+
for it,dr in info_old.dr.items():
20+
info_element[it].dr = dr
21+
for it,Ecut in info_old.Ecut.items():
22+
info_element[it].Ecut = Ecut
23+
for it,Ne in info_old.Ne.items():
24+
info_element[it].Ne = Ne
25+
26+
info_opt = addict.Dict()
27+
info_opt.lr = info_old.lr
28+
info_opt.cal_T = info_old.cal_T
29+
info_opt.cal_smooth = info_old.cal_smooth
30+
31+
return info_stru, info_element, info_opt

tools/opt_orb_pytorch_dpsi/IO/func_C.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@
22
import torch
33
import numpy as np
44

5-
def random_C_init(info):
5+
def random_C_init(info_element):
66
""" C[it][il][ie,iu] <jY|\phi> """
77
C = dict()
8-
for it in info.Nt_all:
9-
C[it] = ND_list(info.Nl[it])
10-
for il in range(info.Nl[it]):
11-
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info.Ne[it], info.Nu[it][il])), dtype=torch.float64, requires_grad=True)
8+
for it in info_element.keys():
9+
C[it] = ND_list(info_element[it].Nl)
10+
for il in range(info_element[it].Nl):
11+
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True)
1212
return C
1313

1414

1515

16-
def read_C_init(file_name,info):
16+
def read_C_init(file_name,info_element):
1717
""" C[it][il][ie,iu] <jY|\phi> """
18-
C = random_C_init(info)
18+
C = random_C_init(info_element)
1919

2020
with open(file_name,"r") as file:
2121

@@ -29,11 +29,12 @@ def read_C_init(file_name,info):
2929
while True:
3030
line = file.readline().strip()
3131
if line.startswith("Type"):
32-
it,il,iu = list(map(int,file.readline().split()));
33-
it=info.Nt_all[it-1]; iu-=1
32+
it,il,iu = file.readline().split();
33+
il = int(il)
34+
iu = int(iu)-1
3435
C_read_index.add((it,il,iu))
3536
line = file.readline().split()
36-
for ie in range(info.Ne[it]):
37+
for ie in range(info_element[it].Ne):
3738
if not line: line = file.readline().split()
3839
C[it][il].data[ie,iu] = float(line.pop(0))
3940
elif line.startswith("</Coefficient>"):
@@ -44,17 +45,17 @@ def read_C_init(file_name,info):
4445

4546

4647

47-
def copy_C(C,info):
48+
def copy_C(C,info_element):
4849
C_copy = dict()
49-
for it in info.Nt_all:
50-
C_copy[it] = ND_list(info.Nl[it])
51-
for il in range(info.Nl[it]):
50+
for it in info_element.keys():
51+
C_copy[it] = ND_list(info_element[it].Nl)
52+
for il in range(info_element[it].Nl):
5253
C_copy[it][il] = C[it][il].clone()
5354
return C_copy
5455

5556

5657

57-
def write_C(file_name,info,C,Spillage):
58+
def write_C(file_name,C,Spillage):
5859
with open(file_name,"w") as file:
5960
print("<Coefficient>", file=file)
6061
#print("\tTotal number of radial orbitals.", file=file)
@@ -70,7 +71,7 @@ def write_C(file_name,info,C,Spillage):
7071
for il,C_tl in enumerate(C_t):
7172
for iu in range(C_tl.size()[1]):
7273
print("\tType\tL\tZeta-Orbital", file=file)
73-
print(f"\t {info.Nt_all.index(it)+1} \t{il}\t {iu+1}", file=file)
74+
print(f"\t {it} \t{il}\t {iu+1}", file=file)
7475
for ie in range(C_tl.size()[0]):
7576
print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file)
7677
print("</Coefficient>", file=file)

tools/opt_orb_pytorch_dpsi/IO/print_orbital.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118
2424
}
2525

26-
def print_orbital(orb,info):
26+
def print_orbital(orb,info_element):
2727
""" orb[it][il][iu][r] """
2828
for it,orb_t in orb.items():
2929
#with open("orb_{0}.dat".format(it),"w") as file:
3030
with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file:
31-
print_orbital_head(file,info,it)
31+
print_orbital_head(file,info_element,it)
3232
for il,orb_tl in enumerate(orb_t):
3333
for iu,orb_tlu in enumerate(orb_tl):
3434
print(""" Type L N""",file=file)
@@ -52,19 +52,19 @@ def plot_orbital(orb,Rcut,dr):
5252
print(file=file)
5353

5454

55-
def print_orbital_head(file,info,it):
55+
def print_orbital_head(file,info_element,it):
5656
print( "---------------------------------------------------------------------------", file=file )
5757
print( "Element {0}".format(it), file=file )
58-
print( "Energy Cutoff(Ry) {0}".format(info.Ecut[it]), file=file )
59-
print( "Radius Cutoff(a.u.) {0}".format(info.Rcut[it]), file=file )
60-
print( "Lmax {0}".format(info.Nl[it]-1), file=file )
58+
print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file )
59+
print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file )
60+
print( "Lmax {0}".format(info_element[it].Nl-1), file=file )
6161
l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1)))
62-
for il,iu in enumerate(info.Nu[it]):
62+
for il,iu in enumerate(info_element[it].Nu):
6363
print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file )
6464
print( "---------------------------------------------------------------------------", file=file )
6565
print( "SUMMARY END", file=file )
6666
print( file=file )
67-
print( "Mesh {0}".format(int(info.Rcut[it]/info.dr[it])+1), file=file )
68-
print( "dr {0}".format(info.dr[it]), file=file )
67+
print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file )
68+
print( "dr {0}".format(info_element[it].dr), file=file )
6969

7070

tools/opt_orb_pytorch_dpsi/IO/read_QSV.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def read_file(info,file_list,V_info):
7171
print("read VI:",ist_true,ik)
7272
vi = read_VI(info_true,V_info,ist_true,data)
7373
VI.append( vi )
74+
print()
7475

7576
info_all = copy.deepcopy(info)
7677
info_all.Nst = sum(info_true.Nk,0)

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import torch_optimizer
1414
import IO.cal_weight
1515
import util
16+
import IO.change_info
17+
import pprint
1618

1719
def main():
1820
seed = int(1000*time.time())%(2**32)
@@ -24,51 +26,59 @@ def main():
2426

2527
weight = IO.cal_weight.cal_weight(weight_info, V_info["same_band"], file_list["origin"])
2628

27-
QI,SI,VI_origin,info = IO.read_QSV.read_file(info_true,file_list["origin"],V_info)
28-
print(info, flush=True)
29+
QI,SI,VI_origin,info_kst = IO.read_QSV.read_file(info_true,file_list["origin"],V_info)
2930
if "linear" in file_list.keys():
3031
QI_linear, SI_linear, VI_linear, info_linear = list(zip(*( IO.read_QSV.read_file(info_true,file,V_info) for file in file_list["linear"] )))
3132

33+
info_stru, info_element, info_opt = IO.change_info.change_info(info_kst,weight)
34+
35+
print(pprint.pformat(info_stru), end="\n"*2, flush=True)
36+
print(pprint.pformat(info_element,width=40), end="\n"*2, flush=True)
37+
print(pprint.pformat(info_opt,width=40), end="\n"*2, flush=True)
38+
3239
if C_init_info["init_from_file"]:
33-
C, C_read_index = IO.func_C.read_C_init( C_init_info["C_init_file"], info )
40+
C, C_read_index = IO.func_C.read_C_init( C_init_info["C_init_file"], info_element )
3441
else:
35-
C = IO.func_C.random_C_init(info)
36-
E = orbital.set_E(info,info.Rcut)
37-
orbital.normalize( orbital.generate_orbital(info,C,E,info.Rcut,info.dr), info.dr,C,flag_norm_C=True)
42+
C = IO.func_C.random_C_init(info_element)
43+
E = orbital.set_E(info_element)
44+
orbital.normalize(
45+
orbital.generate_orbital(info_element,C,E),
46+
{it:info_element[it].dr for it in info_element},
47+
C, flag_norm_C=True)
3848

3949
opt_orb = opt_orbital.Opt_Orbital()
4050

41-
#opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info.lr, eps=1e-8)
42-
#opt = torch.optim.Adam( sum(C.values(),[]), lr=info.lr, eps=1e-20, weight_decay=info.weight_decay)
43-
#opt = radam.RAdam( sum(C.values(),[]), lr=info.lr, eps=1e-20 )
44-
opt = torch_optimizer.SWATS( sum(C.values(),[]), lr=info.lr, eps=1e-20 )
51+
#opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info_opt.lr, eps=1e-8)
52+
#opt = torch.optim.Adam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20, weight_decay=info_opt.weight_decay)
53+
#opt = radam.RAdam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 )
54+
opt = torch_optimizer.SWATS( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 )
4555

4656

4757
with open("Spillage.dat","w") as S_file:
4858

4959
print( "\nSee \"Spillage.dat\" for detail status: " , flush=True )
50-
if info.cal_T:
60+
if info_opt.cal_T:
5161
print( '%5s'%"istep", "%20s"%"Spillage", "%20s"%"T.item()", "%20s"%"Loss", flush=True )
5262
else:
5363
print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
5464

5565
loss_old = np.inf
56-
for istep in range(200):
66+
for istep in range(3):
5767

5868
Spillage = 0
59-
for ist in range(info.Nst):
69+
for ist in range(len(info_stru)):
6070

61-
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info,ist),info,ist)
62-
S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info,ist),info,ist)
71+
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info_stru[ist],info_element),info_stru[ist])
72+
S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info_stru[ist],info_element),info_stru[ist],info_element)
6373
coef = opt_orb.cal_coef(Q,S)
6474
V = opt_orb.cal_V(coef,Q)
6575
V_origin = opt_orb.cal_V_origin(V,V_info)
6676

6777
if "linear" in file_list.keys():
6878
V_linear = [None] * len(file_list["linear"])
6979
for i in range(len(file_list["linear"])):
70-
Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info,ist),info,ist)
71-
S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info,ist),info,ist)
80+
Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist])
81+
S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist],info_element)
7282
V_linear[i] = opt_orb.cal_V_linear(coef,Q_linear,S_linear,V,V_info)
7383

7484
def cal_Spillage(V_delta):
@@ -83,14 +93,14 @@ def cal_delta(VI, V):
8393
for i in range(len(file_list["linear"])):
8494
Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i]))
8595

86-
if info.cal_T:
96+
if info_opt.cal_T:
8797
T = opt_orb.cal_T(C,E)
8898
if not "TSrate" in vars(): TSrate = torch.abs(0.002*Spillage/T).data[0]
8999
Loss = Spillage + TSrate*T
90100
else:
91101
Loss = Spillage
92102

93-
if info.cal_T:
103+
if info_opt.cal_T:
94104
print_content = [istep, Spillage.item(), T.item(), Loss.item()]
95105
else:
96106
print_content = [istep, Spillage.item()]
@@ -100,7 +110,7 @@ def cal_delta(VI, V):
100110

101111
if Loss.item() < loss_old:
102112
loss_old = Loss.item()
103-
C_old = IO.func_C.copy_C(C,info)
113+
C_old = IO.func_C.copy_C(C,info_element)
104114
flag_finish = 0
105115
else:
106116
flag_finish += 1
@@ -113,16 +123,27 @@ def cal_delta(VI, V):
113123
for it,il,iu in C_read_index:
114124
C[it][il].grad[:,iu] = 0
115125
opt.step()
116-
# orbital.normalize( orbital.generate_orbital(info,C,E,info.Rcut,info.dr), info.dr,C,flag_norm_C=True)
117-
118-
orb = orbital.generate_orbital(info,C_old,E,info.Rcut,info.dr)
119-
if info.cal_smooth:
120-
orbital.smooth_orbital(orb,info.Rcut,info.dr,0.1)
121-
orbital.orth(orb,info.dr)
122-
IO.print_orbital.print_orbital(orb,info)
123-
IO.print_orbital.plot_orbital(orb,info.Rcut,info.dr)
124-
125-
IO.func_C.write_C("ORBITAL_RESULTS.txt",info,C_old,Spillage)
126+
# orbital.normalize(
127+
# orbital.generate_orbital(info_element,C,E),
128+
# {it:info_element[it].dr for it in info_element},
129+
# C, flag_norm_C=True)
130+
131+
orb = orbital.generate_orbital(info_element,C_old,E)
132+
if info_opt.cal_smooth:
133+
orbital.smooth_orbital(
134+
orb,
135+
{it:info_element[it].Rcut for it in info_element}, {it:info_element[it].dr for it in info_element},
136+
0.1)
137+
orbital.orth(
138+
orb,
139+
{it:info_element[it].dr for it in info_element})
140+
IO.print_orbital.print_orbital(orb,info_element)
141+
IO.print_orbital.plot_orbital(
142+
orb,
143+
{it:info_element[it].Rcut for it in info_element},
144+
{it:info_element[it].dr for it in info_element})
145+
146+
IO.func_C.write_C("ORBITAL_RESULTS.txt",C_old,Spillage)
126147

127148
print("Time (PyTorch): %s\n"%(time.time()-time_start), flush=True )
128149

0 commit comments

Comments
 (0)