Skip to content

Commit 0c662b2

Browse files
committed
1. In opt_orb_pytorch_dpsi, add V as <\psi|P|\psi>, change V to V_origin.
1 parent 922be2c commit 0c662b2

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def main():
2424

2525
weight = IO.cal_weight.cal_weight(weight_info, V_info["same_band"], file_list["origin"])
2626

27-
QI,SI,VI,info = IO.read_QSV.read_file(info_true,file_list["origin"],V_info)
27+
QI,SI,VI_origin,info = IO.read_QSV.read_file(info_true,file_list["origin"],V_info)
2828
print(info, flush=True)
2929
if "linear" in file_list.keys():
3030
QI_linear, SI_linear, VI_linear, info_linear = list(zip(*( IO.read_QSV.read_file(info_true,file,V_info) for file in file_list["linear"] )))
@@ -60,7 +60,8 @@ def main():
6060

6161
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info,ist),info,ist)
6262
S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info,ist),info,ist)
63-
V = opt_orb.cal_V(Q,S,V_info)
63+
V = opt_orb.cal_V(Q,S)
64+
V_origin = opt_orb.cal_V_origin(V,V_info)
6465

6566
if "linear" in file_list.keys():
6667
V_linear = [None] * len(file_list["linear"])
@@ -76,7 +77,7 @@ def cal_Spillage(V_delta):
7677
def cal_delta(VI, V):
7778
return ((VI[ist]-V)/util.update0(VI[ist])).abs() # abs or **2?
7879

79-
Spillage += 2*cal_Spillage(cal_delta(VI,V))
80+
Spillage += 2*cal_Spillage(cal_delta(VI_origin,V_origin))
8081
if "linear" in file_list.keys():
8182
for i in range(len(file_list["linear"])):
8283
Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i]))

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,24 +113,30 @@ def cal_coef(self,Q,S):
113113

114114

115115

116-
def cal_V(self,Q,S,V_info):
116+
def cal_V(self,Q,S):
117117
"""
118118
<\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
119-
V[ib]
120-
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
121-
Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib,it2*il2*ia2*im2*iu2]
122119
V[ib1,ib2]
123120
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
124121
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
125122
"""
126123
coef = self.cal_coef(Q,S) # coef[ib,it*il*ia*im*iu]
127124

128125
# V[ib1,ib2]
129-
V_tmp = torch_complex.mm( coef, Q.t().conj() ).real
130-
if V_info["same_band"]: V = V_tmp.diag().sqrt()
131-
else: V = V_tmp.sqrt()
126+
V = torch_complex.mm( coef, Q.t().conj() ).real
132127
return V
133-
128+
129+
130+
def cal_V_origin(self,V,V_info):
131+
# V[ib1,ib2]
132+
"""
133+
<\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
134+
V_origin[ib]
135+
V_origin[ib1,ib2]
136+
"""
137+
if V_info["same_band"]: V_origin = V.diag().sqrt()
138+
else: V_origin = V.sqrt()
139+
return V_origin
134140

135141

136142
def cal_V_linear(self,Q,S,Q_linear,S_linear,V,V_info):
@@ -149,7 +155,9 @@ def cal_V_linear(self,Q,S,Q_linear,S_linear,V,V_info):
149155
V_linear_1 = V_linear_1.diag()
150156
V_linear_2 = V_linear_2.diag()
151157
V_linear_3 = V_linear_3.diag()
152-
Z = util.update0(V)
158+
if V_info["same_band"]: Z = V.diag().sqrt()
159+
else: Z = V.sqrt()
160+
Z = util.update0(Z)
153161
V_linear = (-V_linear_1/Z + V_linear_2 + V_linear_3) / Z
154162
return V_linear
155163

0 commit comments

Comments
 (0)