Skip to content

Commit 2ab53a4

Browse files
committed
1. In opt_orb_pytorch_dpsi, add Nb_true, omit calculation of bands whose weight=0
1 parent 049de94 commit 2ab53a4

File tree

5 files changed

+85
-73
lines changed

5 files changed

+85
-73
lines changed

tools/opt_orb_pytorch_dpsi/IO/change_info.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ def change_info(info_old, weight_old):
1111
for ist,weight in enumerate(weight_old):
1212
info_stru[ist].weight = weight
1313
info_stru[ist].Nb = weight.shape[0]
14+
for ib in range(weight.shape[0], 0, -1):
15+
if weight[ib-1]>0:
16+
info_stru[ist].Nb_true = ib
17+
break
1418

1519
info_element = addict.Dict()
1620
for it_index,it in enumerate(info_old.Nt_all):
@@ -36,20 +40,36 @@ def change_info(info_old, weight_old):
3640

3741
"""
3842
info_stru =
39-
[{'Na': {'C': 2},
43+
[{'Na': {'C': 1},
4044
'Nb': 6,
41-
'weight': tensor([0.1250, 0.1250, 0.1150, 0.1150, 0.0200, 0.0000])},
42-
{'Na': {'C': 2},
45+
'Nb_true': 4,
46+
'weight': tensor([0.0333, 0.0111, 0.0111, 0.0111, 0.0000, 0.0000])},
47+
{'Na': {'C': 1},
4348
'Nb': 6,
44-
'weight': tensor([0.1250, 0.1250, 0.0896, 0.0896, 0.0707, 0.0000])}]
49+
'Nb_true': 2,
50+
'weight': tensor([0.0667, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000])},
51+
{'Na': {'C': 1, 'O': 2},
52+
'Nb': 10,
53+
'Nb_true': 8,
54+
'weight': tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0000, 0.0000])}]
4555
4656
info_element =
47-
{'C': {'Ecut': 200,
48-
'Ne': 19,
49-
'Nl': 3,
50-
'Nu': [2, 2, 1],
51-
'Rcut': 6,
52-
'dr': 0.01}}
57+
{'C': {
58+
'Ecut': 200,
59+
'Ne': 19,
60+
'Nl': 3,
61+
'Nu': [2, 2, 1],
62+
'Rcut': 6,
63+
'dr': 0.01,
64+
'index': 0},
65+
'O': {
66+
'Ecut': 200,
67+
'Ne': 19,
68+
'Nl': 3,
69+
'Nu': [3, 2, 1],
70+
'Rcut': 6,
71+
'dr': 0.01,
72+
'index': 1}}
5373
5474
info_opt =
5575
{'cal_T': False,

tools/opt_orb_pytorch_dpsi/IO/read_QSV.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def read_QI(info_stru, info_element, data):
107107
QI[it][il][ib,ia,im,ie] = complex(next(data), next(data))
108108
for it in info_stru.Na.keys():
109109
for il in range(info_element[it].Nl):
110-
QI[it][il] = QI[it][il].view(-1,info_element[it].Ne).conj()
110+
QI[it][il] = QI[it][il][:info_stru.Nb_true,:,:,:].view(-1,info_element[it].Ne).conj()
111111
return QI
112112

113113

@@ -145,14 +145,16 @@ def read_VI(info_stru,V_info,ist,data):
145145
VI = np.empty(info_stru.Nb,dtype=np.float64)
146146
for ib in range(info_stru.Nb):
147147
VI.data[ib] = next(data)
148+
VI = VI[:info_stru.Nb_true]
148149
else:
149-
VI = np.ones(info_stru.Nb,dtype=np.float64)
150+
VI = np.ones(info_stru.Nb_true, dtype=np.float64)
150151
else:
151152
""" VI[ib1,ib2] <psi|psi> """
152153
if V_info["init_from_file"]:
153154
VI = np.empty((info_stru.Nb,info_stru.Nb),dtype=np.float64)
154155
for ib1,ib2 in itertools.product( range(info_stru.Nb), range(info_stru.Nb) ):
155156
VI[ib1,ib2] = next(data)
157+
VI = VI[info_stru.Nb_true, info_stru.Nb_true]
156158
else:
157-
VI = np.eye(info_stru.Nb,info_stru.Nb,dtype=np.float64)
159+
VI = np.eye(info_stru.Nb_true, info_stru.Nb_true, dtype=np.float64)
158160
return torch.from_numpy(VI)

tools/opt_orb_pytorch_dpsi/IO/read_json.py

Lines changed: 45 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,77 +16,67 @@ def read_json(file_name):
1616

1717
""" file_name
1818
{
19-
"file_list":
20-
{
21-
"origin":
22-
[
23-
"~/S2/OUT.ABACUS/S2.psi.dat",
24-
"~/SO2/OUT.ABACUS/SO2.psi.dat",
25-
"~/SO/OUT.ABACUS/SO.psi.dat"
19+
"file_list": {
20+
"origin": [
21+
"~/C_bulk/orb_matrix/test.0.dat",
22+
"~/CO2/orb_matrix/test.0.dat"
2623
],
27-
"linear":
28-
[
24+
"linear": [
2925
[
30-
"~/S2/OUT.ABACUS/S2.dpsi.dat",
31-
"~/SO2/OUT.ABACUS/SO2.dpsi.dat",
32-
"~/SO/OUT.ABACUS/SO.dpsi.dat"
26+
"~/C_bulk/orb_matrix/test.1.dat",
27+
"~/CO2/orb_matrix/test.1.dat"
3328
],
3429
[
35-
"~/S2/OUT.ABACUS/S2.ddpsi.dat",
36-
"~/SO2/OUT.ABACUS/SO2.ddpsi.dat",
37-
"~/SO/OUT.ABACUS/SO.ddpsi.dat"
38-
]
30+
"~/C_bulk/orb_matrix/test.2.dat",
31+
"~/CO2/orb_matrix/test.2.dat"
32+
],
3933
]
4034
},
41-
"info":
42-
{
43-
"Nt_all": ["S","O"],
44-
"Nu": {"S":[3,3,2],"O":[3,3,2]},
45-
"Rcut": {"S":10,"O":10},
46-
"dr": {"S":0.01,"O":0.01},
47-
"Ecut": {"S":100,"O":100},
48-
"lr": 0.01,
49-
"cal_T": true,
50-
"cal_smooth": true
35+
"info": {
36+
"Nt_all": [ "C", "O" ],
37+
"Nu": { "C":[2,2,1], "O":[3,2,1] },
38+
"Rcut": { "C":6, "O":6 },
39+
"dr": { "C":0.01, "O":0.01 },
40+
"Ecut": { "C":200, "O":200 },
41+
"lr": 0.01,
42+
"cal_T": false,
43+
"cal_smooth": false
5144
},
5245
"weight":
5346
{
54-
"stru": [2,3,1.5],
55-
"bands_range": [7,9,7], # "bands_range" and "bands_file" only once
47+
"stru": [1, 2.3],
48+
"bands_range": [10, 15], # "bands_range" and "bands_file" only once
5649
"bands_file":
5750
[
58-
"~/S2/OUT.ABACUS/istate.info",
59-
"~/SO2/OUT.ABACUS/istate.info",
60-
"~/SO/OUT.ABACUS/istate.info"
51+
"~/C_bulk/OUT.ABACUS/istate.info",
52+
"~/CO2/OUT.ABACUS/istate.info"
6153
]
54+
},
55+
"C_init_info": {
56+
"init_from_file": false,
57+
"C_init_file": "~/CO/ORBITAL_RESULTS.txt",
58+
"opt_C_read": false
6259
},
63-
"C_init_info":
64-
{
65-
"init_from_file": false,
66-
"C_init_file": "/public/udata/linpz/try/SIA/pytorch/test/many_atoms/SIA/ORBITAL_RESULTS.txt",
67-
"opt_C_read": false
68-
},
69-
"V_info":
70-
{
71-
"init_from_file": true,
72-
"same_band": false
60+
"V_info": {
61+
"init_from_file": true,
62+
"same_band": true
7363
}
7464
}
7565
"""
7666

7767
""" info
78-
Nt_all ["S", "O"]
79-
Nu {"S":[3,3,2], "O":[3,3,2]}
80-
Nb_true [7, 9, 7]
81-
weight [2, 3, 1.5]
82-
Rcut {"S":10, "O":10}
83-
dr {"S":0.01, "O":0.01}
84-
Ecut {"S":100, "O":100}
85-
lr 0.01
86-
Nl {"S":2, "O":2}
87-
Nst 3
88-
Nt [["S"], ["S","O"], ["S","O"]]
89-
Na [{"S":2}, {"S":1,"O":2}, {"S":1,"O":1}]
90-
Nb [7, 9, 7]
91-
Ne {"S":22, "O":19}
68+
Nt_all ['C', 'O']
69+
Nu {'C': [2, 2, 1], 'O': [3, 2, 1]}
70+
Rcut {'C': 6, 'O': 6}
71+
dr {'C': 0.01, 'O': 0.01}
72+
Ecut {'C': 200, 'O': 200}
73+
lr 0.01
74+
cal_T False
75+
cal_smooth False
76+
Nl {'C': 3, 'O': 3}
77+
Nst 3
78+
Nt [['C'], ['C'], ['C', 'O']]
79+
Na [{'C': 1}, {'C': 1}, {'C': 1, 'O': 2}]
80+
Nb [6, 6, 10]
81+
Ne {'C': 19, 'O': 19}
9282
"""

tools/opt_orb_pytorch_dpsi/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def main():
6868
print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
6969

7070
loss_old = np.inf
71-
for istep in range(200):
71+
for istep in range(10000):
7272

7373
Spillage = 0
7474
for ist in range(len(info_stru)):
@@ -87,7 +87,7 @@ def main():
8787
V_linear[i] = opt_orb.cal_V_linear(coef,Q_linear,S_linear,V,V_info)
8888

8989
def cal_Spillage(V_delta):
90-
Spillage = (V_delta * weight[ist]).sum()
90+
Spillage = (V_delta * weight[ist][:info_stru[ist].Nb_true]).sum()
9191
return Spillage
9292

9393
def cal_delta(VI, V):

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def cal_Q(self,QI,C,info_stru,info_element):
1818

1919
for it in info_stru.Na.keys():
2020
for il in range(info_element[it].Nl):
21-
Q[it][il] = torch.mm( QI[it][il], C[it][il].to(torch.complex128) ).view(info_stru.Nb,-1)
21+
Q[it][il] = torch.mm( QI[it][il], C[it][il].to(torch.complex128) ).view(info_stru.Nb_true,-1)
2222
return Q
2323

2424

@@ -83,8 +83,8 @@ def change_index_Q(self,Q,info_stru): # Q[it][il][ib,ia*im*iu]
8383
Q_cat[ib,it*il*ia*im*iu]
8484
"""
8585
# Q_b[ib][0,it*il*ia*im*iu]
86-
Q_b = ND_list(info_stru.Nb)
87-
for ib in range(info_stru.Nb):
86+
Q_b = ND_list(info_stru.Nb_true)
87+
for ib in range(info_stru.Nb_true):
8888
# Q_[it][il*ia*im*iu]
8989
Q_ = dict()
9090
for it in info_stru.Na.keys():

0 commit comments

Comments
 (0)