Skip to content

Commit 290c9b8

Browse files
committed
1. In opt_orb_pytorch_dpsi, move cal_coef() to main() from cal_V() and cal_V_linear()
1 parent 0c662b2 commit 290c9b8

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ def main():
6060

6161
Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info,ist),info,ist)
6262
S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info,ist),info,ist)
63-
V = opt_orb.cal_V(Q,S)
63+
coef = opt_orb.cal_coef(Q,S)
64+
V = opt_orb.cal_V(coef,Q)
6465
V_origin = opt_orb.cal_V_origin(V,V_info)
6566

6667
if "linear" in file_list.keys():
6768
V_linear = [None] * len(file_list["linear"])
6869
for i in range(len(file_list["linear"])):
6970
Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info,ist),info,ist)
7071
S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info,ist),info,ist)
71-
V_linear[i] = opt_orb.cal_V_linear(Q,S,Q_linear,S_linear,V,V_info)
72+
V_linear[i] = opt_orb.cal_V_linear(coef,Q_linear,S_linear,V,V_info)
7273

7374
def cal_Spillage(V_delta):
7475
Spillage = (V_delta * weight[ist]).sum()
@@ -129,5 +130,5 @@ def cal_delta(VI, V):
129130
if __name__=="__main__":
130131
import sys
131132
np.set_printoptions(threshold=sys.maxsize, linewidth=10000)
132-
print( sys.version, flush=True )
133+
print( sys.version, flush=True )
133134
main()

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,15 @@ def cal_coef(self,Q,S):
113113

114114

115115

116-
def cal_V(self,Q,S):
116+
def cal_V(self,coef,Q):
117+
# coef[ib,it*il*ia*im*iu]
118+
# Q[ib,it*il*ia*im*iu]
117119
"""
118120
<\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
119121
V[ib1,ib2]
120122
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
121123
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
122124
"""
123-
coef = self.cal_coef(Q,S) # coef[ib,it*il*ia*im*iu]
124-
125-
# V[ib1,ib2]
126125
V = torch_complex.mm( coef, Q.t().conj() ).real
127126
return V
128127

@@ -139,15 +138,15 @@ def cal_V_origin(self,V,V_info):
139138
return V_origin
140139

141140

142-
def cal_V_linear(self,Q,S,Q_linear,S_linear,V,V_info):
143-
# Q[ib,it*il*ia*im*iu]
144-
# S[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
141+
def cal_V_linear(self,coef,Q_linear,S_linear,V,V_info):
142+
# coef[ib,it*il*ia*im*iu]
143+
# Q_linear[ib,it*il*ia*im*iu]
144+
# S_linear[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
145+
# V[ib1,ib2]
145146
"""
146147
V_linear[ib]
147148
V_linear[ib1,ib2]
148-
"""
149-
coef = self.cal_coef(Q,S) # coef[ib,it*il*ia*im*iu]
150-
149+
"""
151150
V_linear_1 = coef.mm(S_linear).mm(coef.t().conj()).real
152151
V_linear_2 = Q_linear.mm(coef.t().conj()).real
153152
V_linear_3 = coef.mm(Q_linear.t().conj()).real

0 commit comments

Comments
 (0)