@@ -113,24 +113,30 @@ def cal_coef(self,Q,S):
113113
114114
115115
116- def cal_V (self ,Q ,S , V_info ):
116+ def cal_V (self ,Q ,S ):
117117 """
118118 <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
119- V[ib]
120- = sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
121- Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib,it2*il2*ia2*im2*iu2]
122119 V[ib1,ib2]
123120 = sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
124121 Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
125122 """
126123 coef = self .cal_coef (Q ,S ) # coef[ib,it*il*ia*im*iu]
127124
128125 # V[ib1,ib2]
129- V_tmp = torch_complex .mm ( coef , Q .t ().conj () ).real
130- if V_info ["same_band" ]: V = V_tmp .diag ().sqrt ()
131- else : V = V_tmp .sqrt ()
126+ V = torch_complex .mm ( coef , Q .t ().conj () ).real
132127 return V
133-
128+
129+
130+ def cal_V_origin (self ,V ,V_info ):
131+ # V[ib1,ib2]
132+ """
133+ <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
134+ V_origin[ib]
135+ V_origin[ib1,ib2]
136+ """
137+ if V_info ["same_band" ]: V_origin = V .diag ().sqrt ()
138+ else : V_origin = V .sqrt ()
139+ return V_origin
134140
135141
136142 def cal_V_linear (self ,Q ,S ,Q_linear ,S_linear ,V ,V_info ):
@@ -149,7 +155,9 @@ def cal_V_linear(self,Q,S,Q_linear,S_linear,V,V_info):
149155 V_linear_1 = V_linear_1 .diag ()
150156 V_linear_2 = V_linear_2 .diag ()
151157 V_linear_3 = V_linear_3 .diag ()
152- Z = util .update0 (V )
158+ if V_info ["same_band" ]: Z = V .diag ().sqrt ()
159+ else : Z = V .sqrt ()
160+ Z = util .update0 (Z )
153161 V_linear = (- V_linear_1 / Z + V_linear_2 + V_linear_3 ) / Z
154162 return V_linear
155163
0 commit comments