Skip to content

Commit 44561f2

Browse files
authored
Merge pull request #88 from PeizeLin/develop
1. in opt_orb_pytorch_dpsi, delete ist in class Opt_Orbital
2 parents 3c7a680 + 97faf88 commit 44561f2

File tree

9 files changed

+270
-227
lines changed

9 files changed

+270
-227
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import addict
2+
3+
def change_info(info_old, weight_old):
4+
info_stru = [None] * info_old.Nst
5+
for ist in range(len(info_stru)):
6+
info_stru[ist] = addict.Dict()
7+
for ist,Na in enumerate(info_old.Na):
8+
info_stru[ist].Na = Na
9+
for ist,weight in enumerate(weight_old):
10+
info_stru[ist].weight = weight
11+
info_stru[ist].Nb = weight.shape[0]
12+
13+
info_element = addict.Dict()
14+
for it,Nu in info_old.Nu.items():
15+
info_element[it].Nu = Nu
16+
info_element[it].Nl = len(Nu)
17+
for it,Rcut in info_old.Rcut.items():
18+
info_element[it].Rcut = Rcut
19+
for it,dr in info_old.dr.items():
20+
info_element[it].dr = dr
21+
for it,Ecut in info_old.Ecut.items():
22+
info_element[it].Ecut = Ecut
23+
for it,Ne in info_old.Ne.items():
24+
info_element[it].Ne = Ne
25+
26+
info_opt = addict.Dict()
27+
info_opt.lr = info_old.lr
28+
info_opt.cal_T = info_old.cal_T
29+
info_opt.cal_smooth = info_old.cal_smooth
30+
31+
return info_stru, info_element, info_opt

tools/opt_orb_pytorch_dpsi/IO/func_C.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@
22
import torch
33
import numpy as np
44

5-
def random_C_init(info):
5+
def random_C_init(info_element):
66
""" C[it][il][ie,iu] <jY|\phi> """
77
C = dict()
8-
for it in info.Nt_all:
9-
C[it] = ND_list(info.Nl[it])
10-
for il in range(info.Nl[it]):
11-
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info.Ne[it], info.Nu[it][il])), dtype=torch.float64, requires_grad=True)
8+
for it in info_element.keys():
9+
C[it] = ND_list(info_element[it].Nl)
10+
for il in range(info_element[it].Nl):
11+
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True)
1212
return C
1313

1414

1515

16-
def read_C_init(file_name,info):
16+
def read_C_init(file_name,info_element):
1717
""" C[it][il][ie,iu] <jY|\phi> """
18-
C = random_C_init(info)
18+
C = random_C_init(info_element)
1919

2020
with open(file_name,"r") as file:
2121

@@ -29,11 +29,12 @@ def read_C_init(file_name,info):
2929
while True:
3030
line = file.readline().strip()
3131
if line.startswith("Type"):
32-
it,il,iu = list(map(int,file.readline().split()));
33-
it=info.Nt_all[it-1]; iu-=1
32+
it,il,iu = file.readline().split();
33+
il = int(il)
34+
iu = int(iu)-1
3435
C_read_index.add((it,il,iu))
3536
line = file.readline().split()
36-
for ie in range(info.Ne[it]):
37+
for ie in range(info_element[it].Ne):
3738
if not line: line = file.readline().split()
3839
C[it][il].data[ie,iu] = float(line.pop(0))
3940
elif line.startswith("</Coefficient>"):
@@ -44,17 +45,17 @@ def read_C_init(file_name,info):
4445

4546

4647

47-
def copy_C(C,info):
48+
def copy_C(C,info_element):
4849
C_copy = dict()
49-
for it in info.Nt_all:
50-
C_copy[it] = ND_list(info.Nl[it])
51-
for il in range(info.Nl[it]):
50+
for it in info_element.keys():
51+
C_copy[it] = ND_list(info_element[it].Nl)
52+
for il in range(info_element[it].Nl):
5253
C_copy[it][il] = C[it][il].clone()
5354
return C_copy
5455

5556

5657

57-
def write_C(file_name,info,C,Spillage):
58+
def write_C(file_name,C,Spillage):
5859
with open(file_name,"w") as file:
5960
print("<Coefficient>", file=file)
6061
#print("\tTotal number of radial orbitals.", file=file)
@@ -70,7 +71,7 @@ def write_C(file_name,info,C,Spillage):
7071
for il,C_tl in enumerate(C_t):
7172
for iu in range(C_tl.size()[1]):
7273
print("\tType\tL\tZeta-Orbital", file=file)
73-
print(f"\t {info.Nt_all.index(it)+1} \t{il}\t {iu+1}", file=file)
74+
print(f"\t {it} \t{il}\t {iu+1}", file=file)
7475
for ie in range(C_tl.size()[0]):
7576
print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file)
7677
print("</Coefficient>", file=file)

tools/opt_orb_pytorch_dpsi/IO/print_orbital.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118
2424
}
2525

26-
def print_orbital(orb,info):
26+
def print_orbital(orb,info_element):
2727
""" orb[it][il][iu][r] """
2828
for it,orb_t in orb.items():
2929
#with open("orb_{0}.dat".format(it),"w") as file:
3030
with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file:
31-
print_orbital_head(file,info,it)
31+
print_orbital_head(file,info_element,it)
3232
for il,orb_tl in enumerate(orb_t):
3333
for iu,orb_tlu in enumerate(orb_tl):
3434
print(""" Type L N""",file=file)
@@ -52,19 +52,19 @@ def plot_orbital(orb,Rcut,dr):
5252
print(file=file)
5353

5454

55-
def print_orbital_head(file,info,it):
55+
def print_orbital_head(file,info_element,it):
5656
print( "---------------------------------------------------------------------------", file=file )
5757
print( "Element {0}".format(it), file=file )
58-
print( "Energy Cutoff(Ry) {0}".format(info.Ecut[it]), file=file )
59-
print( "Radius Cutoff(a.u.) {0}".format(info.Rcut[it]), file=file )
60-
print( "Lmax {0}".format(info.Nl[it]-1), file=file )
58+
print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file )
59+
print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file )
60+
print( "Lmax {0}".format(info_element[it].Nl-1), file=file )
6161
l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1)))
62-
for il,iu in enumerate(info.Nu[it]):
62+
for il,iu in enumerate(info_element[it].Nu):
6363
print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file )
6464
print( "---------------------------------------------------------------------------", file=file )
6565
print( "SUMMARY END", file=file )
6666
print( file=file )
67-
print( "Mesh {0}".format(int(info.Rcut[it]/info.dr[it])+1), file=file )
68-
print( "dr {0}".format(info.dr[it]), file=file )
67+
print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file )
68+
print( "dr {0}".format(info_element[it].dr), file=file )
6969

7070

tools/opt_orb_pytorch_dpsi/IO/read_QSV.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from util import *
22
import torch
3-
import torch_complex
43
import itertools
54
import numpy as np
65
import re
@@ -72,6 +71,7 @@ def read_file(info,file_list,V_info):
7271
print("read VI:",ist_true,ik)
7372
vi = read_VI(info_true,V_info,ist_true,data)
7473
VI.append( vi )
74+
print()
7575

7676
info_all = copy.deepcopy(info)
7777
info_all.Nst = sum(info_true.Nk,0)
@@ -92,22 +92,17 @@ def read_QI(info,ist,data):
9292
for it in info.Nt[ist]:
9393
QI[it] = ND_list(info.Nl[it])
9494
for il in range(info.Nl[it]):
95-
QI[it][il] = torch_complex.ComplexTensor(
96-
np.empty((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]),dtype=np.float64),
97-
np.empty((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]),dtype=np.float64) )
95+
QI[it][il] = torch.zeros((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]), dtype=torch.complex128)
9896
for ib in range(info.Nb[ist]):
9997
for it in info.Nt[ist]:
10098
for ia in range(info.Na[ist][it]):
10199
for il in range(info.Nl[it]):
102100
for im in range(info.Nm(il)):
103101
for ie in range(info.Ne[it]):
104-
QI[it][il].real[ib,ia,im,ie] = next(data)
105-
QI[it][il].imag[ib,ia,im,ie] = next(data)
102+
QI[it][il][ib,ia,im,ie] = complex(next(data), next(data))
106103
for it in info.Nt[ist]:
107104
for il in range(info.Nl[it]):
108-
QI[it][il] = torch_complex.ComplexTensor(
109-
torch.from_numpy(QI[it][il].real).view(-1,info.Ne[it]),
110-
torch.from_numpy(QI[it][il].imag).view(-1,info.Ne[it])).conj()
105+
QI[it][il] = QI[it][il].view(-1,info.Ne[it]).conj()
111106
return QI
112107

113108

@@ -118,9 +113,7 @@ def read_SI(info,ist,data):
118113
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
119114
SI[it1,it2] = ND_list(info.Nl[it1],info.Nl[it2])
120115
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
121-
SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
122-
np.empty((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]),dtype=np.float64),
123-
np.empty((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]),dtype=np.float64) )
116+
SI[it1,it2][il1][il2] = torch.zeros((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]), dtype=torch.complex128)
124117
for it1 in info.Nt[ist]:
125118
for ia1 in range(info.Na[ist][it1]):
126119
for il1 in range(info.Nl[it1]):
@@ -131,13 +124,12 @@ def read_SI(info,ist,data):
131124
for im2 in range(info.Nm(il2)):
132125
for ie1 in range(info.Ne[it1]):
133126
for ie2 in range(info.Ne[it2]):
134-
SI[it1,it2][il1][il2].real[ia1,im1,ie1,ia2,im2,ie2] = next(data)
135-
SI[it1,it2][il1][il2].imag[ia1,im1,ie1,ia2,im2,ie2] = next(data)
136-
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
137-
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
138-
SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
139-
torch.from_numpy(SI[it1,it2][il1][il2].real),
140-
torch.from_numpy(SI[it1,it2][il1][il2].imag))
127+
SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] = complex(next(data), next(data))
128+
# for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
129+
# for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
130+
# SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
131+
# torch.from_numpy(SI[it1,it2][il1][il2].real),
132+
# torch.from_numpy(SI[it1,it2][il1][il2].imag))
141133
return SI
142134

143135

0 commit comments

Comments
 (0)