Skip to content

Commit 922be2c

Browse files
committed
1. in opt_orb_pytorch_dpsi, delete ist in class Opt_Orbital
1 parent e7c1dc1 commit 922be2c

File tree

2 files changed

+122
-135
lines changed

2 files changed

+122
-135
lines changed

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,34 +50,36 @@ def main():
5050
if info.cal_T:
5151
print( '%5s'%"istep", "%20s"%"Spillage", "%20s"%"T.item()", "%20s"%"Loss", flush=True )
5252
else:
53-
print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
53+
print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
54+
5455
loss_old = np.inf
5556
for istep in range(200):
5657

57-
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI,C,info),info)
58-
S = opt_orb.change_index_S(opt_orb.cal_S(SI,C,info),info)
59-
V = opt_orb.cal_V(Q,S,info,V_info)
60-
61-
if "linear" in file_list.keys():
62-
V_linear = [None] * len(file_list["linear"])
63-
for i in range(len(file_list["linear"])):
64-
Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i],C,info),info)
65-
S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i],C,info),info)
66-
V_linear[i] = opt_orb.cal_V_linear(Q,S,Q_linear,S_linear,V,info,V_info)
67-
68-
def cal_Spillage(V_delta):
69-
Spillage = torch.Tensor([0])
70-
for ist, Vi_delta in enumerate(V_delta):
71-
Spillage += (Vi_delta * weight[ist]).sum()
72-
return Spillage
73-
74-
def cal_delta(VI, V):
75-
return ( ((VIi-Vi)/util.update0(VIi)).abs() for VIi,Vi in zip(VI,V) ) # abs or **2?
76-
77-
Spillage = 2*cal_Spillage(cal_delta(VI,V))
78-
if "linear" in file_list.keys():
79-
for i in range(len(file_list["linear"])):
80-
Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i]))
58+
Spillage = 0
59+
for ist in range(info.Nst):
60+
61+
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info,ist),info,ist)
62+
S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info,ist),info,ist)
63+
V = opt_orb.cal_V(Q,S,V_info)
64+
65+
if "linear" in file_list.keys():
66+
V_linear = [None] * len(file_list["linear"])
67+
for i in range(len(file_list["linear"])):
68+
Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info,ist),info,ist)
69+
S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info,ist),info,ist)
70+
V_linear[i] = opt_orb.cal_V_linear(Q,S,Q_linear,S_linear,V,V_info)
71+
72+
def cal_Spillage(V_delta):
73+
Spillage = (V_delta * weight[ist]).sum()
74+
return Spillage
75+
76+
def cal_delta(VI, V):
77+
return ((VI[ist]-V)/util.update0(VI[ist])).abs() # abs or **2?
78+
79+
Spillage += 2*cal_Spillage(cal_delta(VI,V))
80+
if "linear" in file_list.keys():
81+
for i in range(len(file_list["linear"])):
82+
Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i]))
8183

8284
if info.cal_T:
8385
T = opt_orb.cal_T(C,E)

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 95 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -7,165 +7,150 @@
77

88
class Opt_Orbital:
99

10-
def cal_Q(self,QI,C,info):
10+
def cal_Q(self,QI,C,info,ist):
1111
"""
1212
<\psi|\phi> = <\psi|jY> * <jY|\phi>
13-
Q[ist][it][il][ib,ia*im*iu]
14-
= sum_{q} QI[ist][it][il][ib*ia*im,ie] * C[it][il][ie,iu]
13+
Q[it][il][ib,ia*im*iu]
14+
= sum_{q} QI[it][il][ib*ia*im,ie] * C[it][il][ie,iu]
1515
"""
16-
Q = ND_list(info.Nst,element="dict()")
17-
for ist in range(info.Nst):
18-
for it in info.Nt[ist]:
19-
Q[ist][it] = ND_list(info.Nl[it])
16+
Q = dict()
17+
for it in info.Nt[ist]:
18+
Q[it] = ND_list(info.Nl[it])
2019

21-
for ist in range(info.Nst):
22-
for it in info.Nt[ist]:
23-
for il in range(info.Nl[it]):
24-
Q[ist][it][il] = torch_complex.mm( QI[ist][it][il], C[it][il] ).view(info.Nb[ist],-1)
20+
for it in info.Nt[ist]:
21+
for il in range(info.Nl[it]):
22+
Q[it][il] = torch_complex.mm( QI[it][il], C[it][il] ).view(info.Nb[ist],-1)
2523
return Q
2624

2725

2826

29-
def cal_S(self,SI,C,info):
27+
def cal_S(self,SI,C,info,ist):
3028
"""
3129
<\phi|\phi> = <\phi|jY> * <jY|jY> * <jY|\phi>
32-
S[ist][it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
33-
= sum_{ie1 ie2} C^*[it1][il1][ie1,iu1] * SI[ist][it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] * C[it2][[il2][ie2,iu2]
30+
S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
31+
= sum_{ie1 ie2} C^*[it1][il1][ie1,iu1] * SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] * C[it2][[il2][ie2,iu2]
3432
"""
35-
S = ND_list(info.Nst,element="dict()")
36-
for ist in range(info.Nst):
37-
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
38-
S[ist][it1,it2] = ND_list(info.Nl[it1],info.Nl[it2])
33+
S = dict()
34+
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
35+
S[it1,it2] = ND_list(info.Nl[it1],info.Nl[it2])
3936

40-
for ist in range(info.Nst):
41-
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
42-
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
43-
# SI_C[ia1*im1*ie1*ia2*im2,iu2]
44-
SI_C = torch_complex.mm(
45-
SI[ist][it1,it2][il1][il2].view(-1,info.Ne[it2]),
46-
C[it2][il2] )
47-
# SI_C[ia1*im1,ie1,ia2*im2*iu2]
48-
SI_C = SI_C.view( info.Na[ist][it1]*info.Nm(il1), info.Ne[it1], -1 )
49-
# Ct[iu1,ie1]
50-
Ct = C[it1][il1].t()
51-
C_mm = functools.partial(torch_complex.mm,Ct)
52-
# C_SI_C[ia1*im1][iu1,ia2*im2*iu2]
53-
C_SI_C = list(map( C_mm, SI_C ))
54-
# C_SI_C[ia1*im1*iu1,ia2*im2*iu2]
55-
C_SI_C = torch_complex.cat( C_SI_C, dim=0 )
37+
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
38+
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
39+
# SI_C[ia1*im1*ie1*ia2*im2,iu2]
40+
SI_C = torch_complex.mm(
41+
SI[it1,it2][il1][il2].view(-1,info.Ne[it2]),
42+
C[it2][il2] )
43+
# SI_C[ia1*im1,ie1,ia2*im2*iu2]
44+
SI_C = SI_C.view( info.Na[ist][it1]*info.Nm(il1), info.Ne[it1], -1 )
45+
# Ct[iu1,ie1]
46+
Ct = C[it1][il1].t()
47+
C_mm = functools.partial(torch_complex.mm,Ct)
48+
# C_SI_C[ia1*im1][iu1,ia2*im2*iu2]
49+
C_SI_C = list(map( C_mm, SI_C ))
50+
# C_SI_C[ia1*im1*iu1,ia2*im2*iu2]
51+
C_SI_C = torch_complex.cat( C_SI_C, dim=0 )
5652
#??? C_SI_C = C_SI_C.view(info.Na[ist][it1]*info.Nm(il1)*info.Nu[it1][il1],-1)
57-
S[ist][it1,it2][il1][il2] = C_SI_C
53+
S[it1,it2][il1][il2] = C_SI_C
5854
return S
5955

6056

6157

62-
def change_index_S(self,S,info): # S[ist][it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
58+
def change_index_S(self,S,info,ist): # S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
6359
"""
6460
<\phi|\phi>
65-
S_cat[ist][it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
61+
S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
6662
"""
67-
S_cat = ND_list(info.Nst)
68-
for ist in range(info.Nst):
69-
# S_s[it1][il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
70-
S_s = dict()
71-
for it1 in info.Nt[ist]:
72-
# S_st[it2][il1*ia1*im1*iu1,il2*ia2*im2*iu2]
73-
S_st = dict()
74-
for it2 in info.Nt[ist]:
75-
# S_stt[il1][ia1*im1*iu1,il2*ia2*im2*iu2]
76-
S_stt = ND_list(info.Nl[it1])
77-
for il1 in range(info.Nl[it1]):
78-
S_stt[il1] = torch_complex.cat( S[ist][it1,it2][il1], dim=1 )
79-
S_st[it2] = torch_complex.cat( S_stt, dim=0 )
80-
S_s[it1] = torch_complex.cat( list(S_st.values()), dim=1 )
81-
# S_cat[ist][it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
82-
S_cat[ist] = torch_complex.cat( list(S_s.values()), dim=0 )
63+
# S_[it1][il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
64+
S_ = dict()
65+
for it1 in info.Nt[ist]:
66+
# S_t[it2][il1*ia1*im1*iu1,il2*ia2*im2*iu2]
67+
S_t = dict()
68+
for it2 in info.Nt[ist]:
69+
# S_tt[il1][ia1*im1*iu1,il2*ia2*im2*iu2]
70+
S_tt = ND_list(info.Nl[it1])
71+
for il1 in range(info.Nl[it1]):
72+
S_tt[il1] = torch_complex.cat( S[it1,it2][il1], dim=1 )
73+
S_t[it2] = torch_complex.cat( S_tt, dim=0 )
74+
S_[it1] = torch_complex.cat( list(S_t.values()), dim=1 )
75+
# S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
76+
S_cat = torch_complex.cat( list(S_.values()), dim=0 )
8377
return S_cat
8478

8579

8680

87-
def change_index_Q(self,Q,info): # Q[ist][it][il][ib,ia*im*iu]
81+
def change_index_Q(self,Q,info,ist): # Q[it][il][ib,ia*im*iu]
8882
"""
8983
<\psi|\phi>
90-
Q_cat[ist][ib,it*il*ia*im*iu]
84+
Q_cat[ib,it*il*ia*im*iu]
9185
"""
92-
Q_cat = ND_list(info.Nst)
93-
for ist in range(info.Nst):
94-
# Q_b[ib][0,it*il*ia*im*iu]
95-
Q_b = ND_list(info.Nb[ist])
96-
for ib in range(info.Nb[ist]):
97-
# Q_s[it][il*ia*im*iu]
98-
Q_s = dict()
99-
for it in info.Nt[ist]:
100-
# Q_ts[il][ia*im*iu]
101-
Q_ts = [ Q_stl[ib] for Q_stl in Q[ist][it] ]
102-
Q_s[it] = torch_complex.cat(Q_ts)
103-
Q_b[ib] = torch_complex.cat(list(Q_s.values())).view(1,-1)
104-
# Q_cat[ist][ib,it*il*ia*im*iu]
105-
Q_cat[ist] = torch_complex.cat( Q_b, dim=0 )
86+
# Q_b[ib][0,it*il*ia*im*iu]
87+
Q_b = ND_list(info.Nb[ist])
88+
for ib in range(info.Nb[ist]):
89+
# Q_[it][il*ia*im*iu]
90+
Q_ = dict()
91+
for it in info.Nt[ist]:
92+
# Q_ts[il][ia*im*iu]
93+
Q_ts = [ Q_tl[ib] for Q_tl in Q[it] ]
94+
Q_[it] = torch_complex.cat(Q_ts)
95+
Q_b[ib] = torch_complex.cat(list(Q_.values())).view(1,-1)
96+
# Q_cat[ib,it*il*ia*im*iu]
97+
Q_cat = torch_complex.cat( Q_b, dim=0 )
10698
return Q_cat
10799

108100

109101

110-
def cal_coef(self,Q,S,info):
111-
# Q[ist][ib,it*il*ia*im*iu]
112-
# S[ist][it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
102+
def cal_coef(self,Q,S):
103+
# Q[ib,it*il*ia*im*iu]
104+
# S[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
113105
"""
114106
<\psi|\phi> * <\phi|\phi>^{-1}
115-
coef[ist][ib,it*il*ia*im*iu]
116-
= Q[ist][ib,it1*il1*ia1*im1*iu1] * S[ist]{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1}
107+
coef[ib,it*il*ia*im*iu]
108+
= Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1}
117109
"""
118-
coef = ND_list(info.Nst)
119-
for ist in range(info.Nst):
120-
S_I = torch_complex.inverse(S[ist])
121-
coef[ist] = torch_complex.mm(Q[ist], S_I)
110+
S_I = torch_complex.inverse(S)
111+
coef = torch_complex.mm(Q, S_I)
122112
return coef
123113

124114

125115

126-
def cal_V(self,Q,S,info,V_info):
116+
def cal_V(self,Q,S,V_info):
127117
"""
128118
<\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
129-
V[ist][ib]
119+
V[ib]
130120
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
131-
Q[ist][ib,it1*il1*ia1*im1*iu1] * S[ist]{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ist][ib,it2*il2*ia2*im2*iu2]
132-
V[ist][ib1,ib2]
121+
Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib,it2*il2*ia2*im2*iu2]
122+
V[ib1,ib2]
133123
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
134-
Q[ist][ib1,it1*il1*ia1*im1*iu1] * S[ist]{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ist][ib2,it2*il2*ia2*im2*iu2]
124+
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
135125
"""
136-
coef = self.cal_coef(Q,S,info) # coef[ist][ib,it*il*ia*im*iu]
137-
138-
V = ND_list(info.Nst)
139-
Z = ND_list(info.Nst)
140-
for ist in range(info.Nst):
141-
# V[ist][ib1,ib2]
142-
V_tmp = torch_complex.mm( coef[ist], Q[ist].t().conj() ).real
143-
if V_info["same_band"]: V[ist] = V_tmp.diag().sqrt()
144-
else: V[ist] = V_tmp.sqrt()
126+
coef = self.cal_coef(Q,S) # coef[ib,it*il*ia*im*iu]
127+
128+
# V[ib1,ib2]
129+
V_tmp = torch_complex.mm( coef, Q.t().conj() ).real
130+
if V_info["same_band"]: V = V_tmp.diag().sqrt()
131+
else: V = V_tmp.sqrt()
145132
return V
146133

147134

148135

149-
def cal_V_linear(self,Q,S,Q_linear,S_linear,V,info,V_info):
150-
# Q[ist][ib,it*il*ia*im*iu]
151-
# S[ist][it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
136+
def cal_V_linear(self,Q,S,Q_linear,S_linear,V,V_info):
137+
# Q[ib,it*il*ia*im*iu]
138+
# S[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
152139
"""
153-
V_linear[ist][ib]
154-
V_linear[ist][ib1,ib2]
140+
V_linear[ib]
141+
V_linear[ib1,ib2]
155142
"""
156-
coef = self.cal_coef(Q,S,info) # coef[ist][ib,it*il*ia*im*iu]
157-
158-
V_linear = ND_list(info.Nst)
159-
for ist in range(info.Nst):
160-
V_linear_1 = coef[ist].mm(S_linear[ist]).mm(coef[ist].t().conj()).real
161-
V_linear_2 = Q_linear[ist].mm(coef[ist].t().conj()).real
162-
V_linear_3 = coef[ist].mm(Q_linear[ist].t().conj()).real
163-
if V_info["same_band"]:
164-
V_linear_1 = V_linear_1.diag()
165-
V_linear_2 = V_linear_2.diag()
166-
V_linear_3 = V_linear_3.diag()
167-
Z = util.update0(V[ist])
168-
V_linear[ist] = (-V_linear_1/Z + V_linear_2 + V_linear_3) / Z
143+
coef = self.cal_coef(Q,S) # coef[ib,it*il*ia*im*iu]
144+
145+
V_linear_1 = coef.mm(S_linear).mm(coef.t().conj()).real
146+
V_linear_2 = Q_linear.mm(coef.t().conj()).real
147+
V_linear_3 = coef.mm(Q_linear.t().conj()).real
148+
if V_info["same_band"]:
149+
V_linear_1 = V_linear_1.diag()
150+
V_linear_2 = V_linear_2.diag()
151+
V_linear_3 = V_linear_3.diag()
152+
Z = util.update0(V)
153+
V_linear = (-V_linear_1/Z + V_linear_2 + V_linear_3) / Z
169154
return V_linear
170155

171156

0 commit comments

Comments
 (0)