Skip to content

Commit c237afc

Browse files
authored
Merge pull request #92 from PeizeLin/develop
1. In opt_orb_pytorch_dpsi, fix bug
2 parents 4aa1944 + 3c331ef commit c237afc

File tree

5 files changed

+32
-30
lines changed

5 files changed

+32
-30
lines changed

source/src_io/write_wfc_realspace.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ namespace Write_Wfc_Realspace
6060
// t0 t1 t2 t3 t4 t5 t6 t7
6161
// -------------------------------->
6262
// rank0 k0 k1 k2 k3 k4 k5
63-
// \ \ \ \ \ \
63+
// \ \ \ \ \ \
6464
// rank1 k0 k1 k2 k3 k4 k5
65-
// \ \ \ \ \ \
65+
// \ \ \ \ \ \
6666
// rank2 k0 k1 k2 k3 k4 k5
6767

6868

tools/opt_orb_pytorch_dpsi/IO/change_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ def change_info(info_old, weight_old):
1313
info_stru[ist].Nb = weight.shape[0]
1414

1515
info_element = addict.Dict()
16+
for it_index,it in enumerate(info_old.Nt_all):
17+
info_element[it].index = it_index
1618
for it,Nu in info_old.Nu.items():
1719
info_element[it].Nu = Nu
1820
info_element[it].Nl = len(Nu)

tools/opt_orb_pytorch_dpsi/IO/read_QSV.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def read_QSV(info_stru, info_element, file_list, V_info):
6161
ist = 0
6262
for ist_true,file_name in enumerate(file_list):
6363
with open(file_name,"r") as file:
64-
Nk = int(re.compile(r"(\d)+\s+nks").search(file.read()).group(1))
64+
Nk = int(re.compile(r"(\d+)\s+nks").search(file.read()).group(1))
6565
with open(file_name,"r") as file:
6666
data = re.compile(r"<OVERLAP_Q>(.+)</OVERLAP_Q>", re.S).search(file.read())
6767
data = map(float,data.group(1).split())

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -128,29 +128,29 @@ def cal_delta(VI, V):
128128
for it,il,iu in C_read_index:
129129
C[it][il].grad[:,iu] = 0
130130
opt.step()
131-
# orbital.normalize(
132-
# orbital.generate_orbital(info_element,C,E),
133-
# {it:info_element[it].dr for it in info_element},
134-
# C, flag_norm_C=True)
135-
136-
orb = orbital.generate_orbital(info_element,C_old,E)
137-
if info_opt.cal_smooth:
138-
orbital.smooth_orbital(
139-
orb,
140-
{it:info_element[it].Rcut for it in info_element}, {it:info_element[it].dr for it in info_element},
141-
0.1)
142-
orbital.orth(
131+
#orbital.normalize(
132+
# orbital.generate_orbital(info_element,C,E),
133+
# {it:info_element[it].dr for it in info_element},
134+
# C, flag_norm_C=True)
135+
136+
orb = orbital.generate_orbital(info_element,C_old,E)
137+
if info_opt.cal_smooth:
138+
orbital.smooth_orbital(
143139
orb,
144-
{it:info_element[it].dr for it in info_element})
145-
IO.print_orbital.print_orbital(orb,info_element)
146-
IO.print_orbital.plot_orbital(
147-
orb,
148-
{it:info_element[it].Rcut for it in info_element},
149-
{it:info_element[it].dr for it in info_element})
150-
151-
IO.func_C.write_C("ORBITAL_RESULTS.txt",C_old,Spillage)
152-
153-
print("Time (PyTorch): %s\n"%(time.time()-time_start), flush=True )
140+
{it:info_element[it].Rcut for it in info_element}, {it:info_element[it].dr for it in info_element},
141+
0.1)
142+
orbital.orth(
143+
orb,
144+
{it:info_element[it].dr for it in info_element})
145+
IO.print_orbital.print_orbital(orb,info_element)
146+
IO.print_orbital.plot_orbital(
147+
orb,
148+
{it:info_element[it].Rcut for it in info_element},
149+
{it:info_element[it].dr for it in info_element})
150+
151+
IO.func_C.write_C("ORBITAL_RESULTS.txt",C_old,Spillage)
152+
153+
print("Time (PyTorch): %s\n"%(time.time()-time_start), flush=True )
154154

155155

156156
if __name__=="__main__":

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def change_index_S(self,S,info_stru,info_element): # S[it1,it2][il1][il2][
7171
S_tt[il1] = torch.cat( S[it1,it2][il1], dim=1 )
7272
S_t[it2] = torch.cat( S_tt, dim=0 )
7373
S_[it1] = torch.cat( list(S_t.values()), dim=1 )
74-
# S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
74+
# S_cat[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
7575
S_cat = torch.cat( list(S_.values()), dim=0 )
7676
return S_cat
7777

@@ -100,11 +100,11 @@ def change_index_Q(self,Q,info_stru): # Q[it][il][ib,ia*im*iu]
100100

101101
def cal_coef(self,Q,S):
102102
# Q[ib,it*il*ia*im*iu]
103-
# S[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
103+
# S[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
104104
"""
105105
<\psi|\phi> * <\phi|\phi>^{-1}
106106
coef[ib,it*il*ia*im*iu]
107-
= Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1}
107+
= Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1}
108108
"""
109109
S_I = torch.inverse(S)
110110
coef = torch.mm(Q, S_I)
@@ -119,7 +119,7 @@ def cal_V(self,coef,Q):
119119
<\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
120120
V[ib1,ib2]
121121
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
122-
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
122+
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
123123
"""
124124
V = torch.mm( coef, Q.t().conj() ).real
125125
return V
@@ -140,7 +140,7 @@ def cal_V_origin(self,V,V_info):
140140
def cal_V_linear(self,coef,Q_linear,S_linear,V,V_info):
141141
# coef[ib,it*il*ia*im*iu]
142142
# Q_linear[ib,it*il*ia*im*iu]
143-
# S_linear[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
143+
# S_linear[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
144144
# V[ib1,ib2]
145145
"""
146146
V_linear[ib]

0 commit comments

Comments
 (0)