Skip to content

Commit 4aa1944

Browse files
authored
Merge pull request #90 from PeizeLin/develop
refactor opt_orb_pytorh_dpsi
2 parents fda0edd + 7aef2cf commit 4aa1944

File tree

3 files changed

+131
-77
lines changed

3 files changed

+131
-77
lines changed

tools/opt_orb_pytorch_dpsi/IO/change_info.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import addict
2+
import util
3+
import itertools
24

35
def change_info(info_old, weight_old):
46
info_stru = [None] * info_old.Nst
@@ -28,4 +30,47 @@ def change_info(info_old, weight_old):
2830
info_opt.cal_T = info_old.cal_T
2931
info_opt.cal_smooth = info_old.cal_smooth
3032

31-
return info_stru, info_element, info_opt
33+
return info_stru, info_element, info_opt
34+
35+
"""
36+
info_stru =
37+
[{'Na': {'C': 2},
38+
'Nb': 6,
39+
'weight': tensor([0.1250, 0.1250, 0.1150, 0.1150, 0.0200, 0.0000])},
40+
{'Na': {'C': 2},
41+
'Nb': 6,
42+
'weight': tensor([0.1250, 0.1250, 0.0896, 0.0896, 0.0707, 0.0000])}]
43+
44+
info_element =
45+
{'C': {'Ecut': 200,
46+
'Ne': 19,
47+
'Nl': 3,
48+
'Nu': [2, 2, 1],
49+
'Rcut': 6,
50+
'dr': 0.01}}
51+
52+
info_opt =
53+
{'cal_T': False,
54+
'cal_smooth': False,
55+
'lr': 0.01}
56+
"""
57+
58+
59+
def get_info_max(info_stru, info_element):
60+
info_max = [None] * len(info_stru)
61+
for ist in range(len(info_stru)):
62+
Nt = info_stru[ist].Na.keys()
63+
info_max[ist] = addict.Dict()
64+
info_max[ist].Nt = len(Nt)
65+
info_max[ist].Na = max((info_stru[ist].Na[it] for it in Nt))
66+
info_max[ist].Nl = max([info_element[it].Nl for it in Nt])
67+
info_max[ist].Nm = max((util.Nm(info_element[it].Nl-1) for it in Nt))
68+
info_max[ist].Nu = max(itertools.chain.from_iterable([info_element[it].Nu for it in Nt]))
69+
info_max[ist].Ne = max((info_element[it].Ne for it in Nt))
70+
info_max[ist].Nb = info_stru[ist].Nb
71+
return info_max
72+
73+
"""
74+
[{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2},
75+
{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}]
76+
"""
Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,133 @@
1-
from util import *
1+
import util
22
import torch
33
import itertools
44
import numpy as np
55
import re
66
import copy
77

8-
def read_file(info,file_list,V_info):
8+
def read_file_head(info,file_list):
99
""" QI[ist][it][il][ib*ia*im,ie] <\psi|jY> """
1010
""" SI[ist][it1][it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] <jY|jY> """
1111
""" VI[ist][ib] <\psi|\psi> """
1212
info_true = copy.deepcopy(info)
1313
info_true.Nst = len(file_list)
14-
info_true.Nt = ND_list(info_true.Nst,element="list()")
15-
info_true.Na = ND_list(info_true.Nst,element="dict()")
16-
info_true.Nb = ND_list(info_true.Nst)
17-
info_true.Nk = ND_list(info_true.Nst)
14+
info_true.Nt = util.ND_list(info_true.Nst,element="list()")
15+
info_true.Na = util.ND_list(info_true.Nst,element="dict()")
16+
info_true.Nb = util.ND_list(info_true.Nst)
17+
info_true.Nk = util.ND_list(info_true.Nst)
1818
info_true.Ne = dict()
19-
QI=[]; SI=[]; VI=[]
2019

2120
for ist_true,file_name in enumerate(file_list):
2221
print(file_name)
2322
with open(file_name,"r") as file:
2423

25-
ignore_line(file,4)
24+
util.ignore_line(file,4)
2625
Nt_tmp = int(file.readline().split()[0])
2726
for it in range(Nt_tmp):
2827
t_tmp = file.readline().split()[0]
2928
assert t_tmp in info.Nt_all
3029
info_true.Nt[ist_true].append( t_tmp )
3130
info_true.Na[ist_true][t_tmp] = int(file.readline().split()[0])
32-
ignore_line( file, info_true.Na[ist_true][t_tmp] )
33-
ignore_line(file,6)
31+
util.ignore_line( file, info_true.Na[ist_true][t_tmp] )
32+
util.ignore_line(file,6)
3433
Nl_ist = int(file.readline().split()[0])+1
3534
for it,Nl_C in info.Nl.items():
3635
print(it,Nl_ist,Nl_C)
3736
assert Nl_ist>=Nl_C
3837
info_true.Nl[it] = Nl_ist
3938
info_true.Nk[ist_true] = int(file.readline().split()[0])
4039
info_true.Nb[ist_true] = int(file.readline().split()[0])
41-
ignore_line(file,1)
42-
# Ne_tmp = list(map(int,file.readline().split()[:Nt_tmp]))
43-
# for it,Ne in zip(info_true.Nt[ist_true],Ne_tmp):
44-
# assert info_true.Ne.setdefault(it,Ne)==Ne
40+
util.ignore_line(file,1)
41+
#Ne_tmp = list(map(int,file.readline().split()[:Nt_tmp]))
42+
#for it,Ne in zip(info_true.Nt[ist_true],Ne_tmp):
43+
# assert info_true.Ne.setdefault(it,Ne)==Ne
4544
Ne_tmp = int(file.readline().split()[0])
4645
for it in info_true.Nt[ist_true]:
4746
info_true.Ne[it] = Ne_tmp
4847

48+
info_all = copy.deepcopy(info)
49+
info_all.Nst = sum(info_true.Nk,0)
50+
repeat_Nk = lambda x: list( itertools.chain.from_iterable( map( lambda x:itertools.repeat(*x), zip(x,info_true.Nk) ) ) )
51+
info_all.Nt = repeat_Nk(info_true.Nt)
52+
info_all.Na = repeat_Nk(info_true.Na)
53+
info_all.Nb = repeat_Nk(info_true.Nb)
54+
info_all.Ne = info_true.Ne
55+
56+
return info_all
57+
58+
59+
def read_QSV(info_stru, info_element, file_list, V_info):
60+
QI=[]; SI=[]; VI=[]
61+
ist = 0
4962
for ist_true,file_name in enumerate(file_list):
63+
with open(file_name,"r") as file:
64+
Nk = int(re.compile(r"(\d)+\s+nks").search(file.read()).group(1))
5065
with open(file_name,"r") as file:
5166
data = re.compile(r"<OVERLAP_Q>(.+)</OVERLAP_Q>", re.S).search(file.read())
5267
data = map(float,data.group(1).split())
53-
for ik in range(info_true.Nk[ist_true]):
68+
for ik in range(Nk):
5469
print("read QI:",ist_true,ik)
55-
qi = read_QI(info_true,ist_true,data)
70+
qi = read_QI(info_stru[ist+ik], info_element, data)
5671
QI.append( qi )
5772
with open(file_name,"r") as file:
5873
data = re.compile(r"<OVERLAP_Sq>(.+)</OVERLAP_Sq>", re.S).search(file.read())
5974
data = map(float,data.group(1).split())
60-
for ik in range(info_true.Nk[ist_true]):
75+
for ik in range(Nk):
6176
print("read SI:",ist_true,ik)
62-
si = read_SI(info_true,ist_true,data)
77+
si = read_SI(info_stru[ist+ik], info_element, data)
6378
SI.append( si )
6479
if V_info["init_from_file"]:
6580
with open(file_name,"r") as file:
6681
data = re.compile(r"<OVERLAP_V>(.+)</OVERLAP_V>", re.S).search(file.read())
6782
data = map(float,data.group(1).split())
6883
else:
6984
data = ()
70-
for ik in range(info_true.Nk[ist_true]):
85+
for ik in range(Nk):
7186
print("read VI:",ist_true,ik)
72-
vi = read_VI(info_true,V_info,ist_true,data)
87+
vi = read_VI(info_stru[ist+ik], V_info, ist_true, data)
7388
VI.append( vi )
89+
ist += Nk
7490
print()
75-
76-
info_all = copy.deepcopy(info)
77-
info_all.Nst = sum(info_true.Nk,0)
78-
repeat_Nk = lambda x: list( itertools.chain.from_iterable( map( lambda x:itertools.repeat(*x), zip(x,info_true.Nk) ) ) )
79-
info_all.Nt = repeat_Nk(info_true.Nt)
80-
info_all.Na = repeat_Nk(info_true.Na)
81-
info_all.Nb = repeat_Nk(info_true.Nb)
82-
info_all.Ne = info_true.Ne
83-
84-
return QI,SI,VI,info_all
91+
return QI,SI,VI
8592

8693

87-
88-
89-
def read_QI(info,ist,data):
94+
def read_QI(info_stru, info_element, data):
9095
""" QI[it][il][ib*ia*im,ie] <\psi|jY> """
9196
QI = dict()
92-
for it in info.Nt[ist]:
93-
QI[it] = ND_list(info.Nl[it])
94-
for il in range(info.Nl[it]):
95-
QI[it][il] = torch.zeros((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]), dtype=torch.complex128)
96-
for ib in range(info.Nb[ist]):
97-
for it in info.Nt[ist]:
98-
for ia in range(info.Na[ist][it]):
99-
for il in range(info.Nl[it]):
100-
for im in range(info.Nm(il)):
101-
for ie in range(info.Ne[it]):
97+
for it in info_stru.Na.keys():
98+
QI[it] = util.ND_list(info_element[it].Nl)
99+
for il in range(info_element[it].Nl):
100+
QI[it][il] = torch.zeros((info_stru.Nb, info_stru.Na[it], util.Nm(il), info_element[it].Ne), dtype=torch.complex128)
101+
for ib in range(info_stru.Nb):
102+
for it in info_stru.Na.keys():
103+
for ia in range(info_stru.Na[it]):
104+
for il in range(info_element[it].Nl):
105+
for im in range(util.Nm(il)):
106+
for ie in range(info_element[it].Ne):
102107
QI[it][il][ib,ia,im,ie] = complex(next(data), next(data))
103-
for it in info.Nt[ist]:
104-
for il in range(info.Nl[it]):
105-
QI[it][il] = QI[it][il].view(-1,info.Ne[it]).conj()
108+
for it in info_stru.Na.keys():
109+
for il in range(info_element[it].Nl):
110+
QI[it][il] = QI[it][il].view(-1,info_element[it].Ne).conj()
106111
return QI
107112

108113

109-
110-
def read_SI(info,ist,data):
114+
def read_SI(info_stru, info_element, data):
111115
""" SI[it1,it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] <jY|jY> """
112116
SI = dict()
113-
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
114-
SI[it1,it2] = ND_list(info.Nl[it1],info.Nl[it2])
115-
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
116-
SI[it1,it2][il1][il2] = torch.zeros((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]), dtype=torch.complex128)
117-
for it1 in info.Nt[ist]:
118-
for ia1 in range(info.Na[ist][it1]):
119-
for il1 in range(info.Nl[it1]):
120-
for im1 in range(info.Nm(il1)):
121-
for it2 in info.Nt[ist]:
122-
for ia2 in range(info.Na[ist][it2]):
123-
for il2 in range(info.Nl[it2]):
124-
for im2 in range(info.Nm(il2)):
125-
for ie1 in range(info.Ne[it1]):
126-
for ie2 in range(info.Ne[it2]):
117+
for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ):
118+
SI[it1,it2] = util.ND_list(info_element[it1].Nl, info_element[it2].Nl)
119+
for il1,il2 in itertools.product( range(info_element[it1].Nl), range(info_element[it2].Nl) ):
120+
SI[it1,it2][il1][il2] = torch.zeros((info_stru.Na[it1], util.Nm(il1), info_element[it1].Ne, info_stru.Na[it2], util.Nm(il2), info_element[it2].Ne), dtype=torch.complex128)
121+
for it1 in info_stru.Na.keys():
122+
for ia1 in range(info_stru.Na[it1]):
123+
for il1 in range(info_element[it1].Nl):
124+
for im1 in range(util.Nm(il1)):
125+
for it2 in info_stru.Na.keys():
126+
for ia2 in range(info_stru.Na[it2]):
127+
for il2 in range(info_element[it2].Nl):
128+
for im2 in range(util.Nm(il2)):
129+
for ie1 in range(info_element[it1].Ne):
130+
for ie2 in range(info_element[it2].Ne):
127131
SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] = complex(next(data), next(data))
128132
# for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
129133
# for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
@@ -134,21 +138,21 @@ def read_SI(info,ist,data):
134138

135139

136140

137-
def read_VI(info,V_info,ist,data):
141+
def read_VI(info_stru,V_info,ist,data):
138142
if V_info["same_band"]:
139143
""" VI[ib] <psi|psi> """
140144
if V_info["init_from_file"]:
141-
VI = np.empty(info.Nb[ist],dtype=np.float64)
142-
for ib in range(info.Nb[ist]):
145+
VI = np.empty(info_stru.Nb,dtype=np.float64)
146+
for ib in range(info_stru.Nb):
143147
VI.data[ib] = next(data)
144148
else:
145-
VI = np.ones(info.Nb[ist],dtype=np.float64)
149+
VI = np.ones(info_stru.Nb,dtype=np.float64)
146150
else:
147151
""" VI[ib1,ib2] <psi|psi> """
148152
if V_info["init_from_file"]:
149-
VI = np.empty((info.Nb[ist],info.Nb[ist]),dtype=np.float64)
150-
for ib1,ib2 in itertools.product( range(info.Nb[ist]), range(info.Nb[ist]) ):
153+
VI = np.empty((info_stru.Nb,info_stru.Nb),dtype=np.float64)
154+
for ib1,ib2 in itertools.product( range(info_stru.Nb), range(info_stru.Nb) ):
151155
VI[ib1,ib2] = next(data)
152156
else:
153-
VI = np.eye(info.Nb[ist],info.Nb[ist],dtype=np.float64)
157+
VI = np.eye(info_stru.Nb,info_stru.Nb,dtype=np.float64)
154158
return torch.from_numpy(VI)

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,20 @@ def main():
2626

2727
weight = IO.cal_weight.cal_weight(weight_info, V_info["same_band"], file_list["origin"])
2828

29-
QI,SI,VI_origin,info_kst = IO.read_QSV.read_file(info_true,file_list["origin"],V_info)
30-
if "linear" in file_list.keys():
31-
QI_linear, SI_linear, VI_linear, info_linear = list(zip(*( IO.read_QSV.read_file(info_true,file,V_info) for file in file_list["linear"] )))
29+
info_kst = IO.read_QSV.read_file_head(info_true,file_list["origin"])
3230

3331
info_stru, info_element, info_opt = IO.change_info.change_info(info_kst,weight)
32+
info_max = IO.change_info.get_info_max(info_stru, info_element)
33+
34+
print("info_kst:", info_kst, sep="\n", end="\n"*2, flush=True)
35+
print("info_stru:", pprint.pformat(info_stru), sep="\n", end="\n"*2, flush=True)
36+
print("info_element:", pprint.pformat(info_element,width=40), sep="\n", end="\n"*2, flush=True)
37+
print("info_opt:", pprint.pformat(info_opt,width=40), sep="\n", end="\n"*2, flush=True)
38+
print("info_max:", pprint.pformat(info_max), sep="\n", end="\n"*2, flush=True)
3439

35-
print(pprint.pformat(info_stru), end="\n"*2, flush=True)
36-
print(pprint.pformat(info_element,width=40), end="\n"*2, flush=True)
37-
print(pprint.pformat(info_opt,width=40), end="\n"*2, flush=True)
40+
QI,SI,VI_origin = IO.read_QSV.read_QSV(info_stru, info_element, file_list["origin"], V_info)
41+
if "linear" in file_list.keys():
42+
QI_linear, SI_linear, VI_linear = list(zip(*( IO.read_QSV.read_QSV(info_stru, info_element, file, V_info) for file in file_list["linear"] )))
3843

3944
if C_init_info["init_from_file"]:
4045
C, C_read_index = IO.func_C.read_C_init( C_init_info["C_init_file"], info_element )
@@ -63,7 +68,7 @@ def main():
6368
print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
6469

6570
loss_old = np.inf
66-
for istep in range(3):
71+
for istep in range(200):
6772

6873
Spillage = 0
6974
for ist in range(len(info_stru)):

0 commit comments

Comments
 (0)