Skip to content

Commit e220ed0

Browse files
authored
Merge pull request #457 from abacusmodeling/develop
merge scalapack_gvx bugfix and related DFTU-bugfix
2 parents 551a1af + 6cff14f commit e220ed0

File tree

14 files changed

+1041
-6
lines changed

14 files changed

+1041
-6
lines changed

source/input.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,9 +1997,9 @@ bool Input::Read(const std::string &fn)
19971997
exit(0);
19981998
}
19991999

2000-
if(strcmp("genelpa", ks_solver.c_str())!=0)
2000+
if(strcmp("genelpa", ks_solver.c_str())!=0 && strcmp(ks_solver.c_str(),"scalapack_gvx")!=0 )
20012001
{
2002-
std::cout << " WRONG ARGUMENTS OF ks_solver in DFT+U routine, only genelpa is support " << std::endl;
2002+
std::cout << " WRONG ARGUMENTS OF ks_solver in DFT+U routine, only genelpa and scalapack_gvx are supportted " << std::endl;
20032003
exit(0);
20042004
}
20052005

source/src_lcao/DM_gamma.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,7 @@ void Local_Orbital_Charge::allocate_gamma(const Grid_Technique &gt)
267267
setAlltoallvParameter(GlobalC::ParaO.comm_2D, GlobalC::ParaO.blacs_ctxt, GlobalC::ParaO.nb);
268268

269269
// Peize Lin test 2019-01-16
270-
if (GlobalV::KS_SOLVER=="genelpa") //LiuXh add 2021-09-06, clear memory, _2d only used in genelpa solver
271-
{
272-
wfc_dm_2d.init();
273-
}
270+
wfc_dm_2d.init();
274271

275272
if(GlobalC::wf.start_wfc=="file")
276273
{
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import IO.read_istate
2+
import torch
3+
import re
4+
import functools
5+
import operator
6+
7+
def cal_weight(info_weight, flag_same_band, stru_file_list=None):
8+
""" weight[ist][ib] """
9+
10+
if "bands_file" in info_weight.keys():
11+
if "bands_range" in info_weight.keys():
12+
raise IOError('"bands_file" and "bands_range" only once')
13+
14+
weight = [] # weight[ist][ib]
15+
for weight_stru, file_name in zip(info_weight["stru"], info_weight["bands_file"]):
16+
occ = IO.read_istate.read_istate(file_name)
17+
weight += [occ_k * weight_stru for occ_k in occ]
18+
19+
elif "bands_range" in info_weight.keys():
20+
k_weight = read_k_weight(stru_file_list) # k_weight[ist][ik]
21+
nbands = read_nbands(stru_file_list) # nbands[ist]
22+
23+
st_weight = [] # st_weight[ist][ib]
24+
for weight_stru, bands_range, nbands_ist in zip(info_weight["stru"], info_weight["bands_range"], nbands):
25+
st_weight_tmp = torch.zeros((nbands_ist,))
26+
st_weight_tmp[:bands_range] = weight_stru
27+
st_weight.append( st_weight_tmp )
28+
29+
weight = [] # weight[ist][ib]
30+
for ist,_ in enumerate(k_weight):
31+
for ik,_ in enumerate(k_weight[ist]):
32+
weight.append(st_weight[ist] * k_weight[ist][ik])
33+
34+
else:
35+
raise IOError('"bands_file" and "bands_range" must once')
36+
37+
38+
if not flag_same_band:
39+
for ist,_ in enumerate(weight):
40+
weight[ist] = torch.tensordot(weight[ist], weight[ist], dims=0)
41+
42+
43+
normalization = functools.reduce(operator.add, map(torch.sum, weight), 0)
44+
weight = list(map(lambda x:x/normalization, weight))
45+
46+
return weight
47+
48+
49+
def read_k_weight(stru_file_list):
50+
""" weight[ist][ik] """
51+
weight = [] # weight[ist][ik]
52+
for file_name in stru_file_list:
53+
weight_k = [] # weight_k[ik]
54+
with open(file_name,"r") as file:
55+
data = re.compile(r"<WEIGHT_OF_KPOINTS>(.+)</WEIGHT_OF_KPOINTS>", re.S).search(file.read()).group(1).split("\n")
56+
for line in data:
57+
line = line.strip()
58+
if line:
59+
weight_k.append(float(line.split()[-1]))
60+
weight.append(weight_k)
61+
return weight
62+
63+
64+
def read_nbands(stru_file_list):
65+
""" nbands[ib] """
66+
nbands = []
67+
for file_name in stru_file_list:
68+
with open(file_name,"r") as file:
69+
nbands.append(int(re.compile(r"(\d+)\s+nbands").search(file.read()).group(1)))
70+
return nbands
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from util import *
2+
import torch
3+
import numpy as np
4+
5+
def random_C_init(info):
6+
""" C[it][il][ie,iu] <jY|\phi> """
7+
C = dict()
8+
for it in info.Nt_all:
9+
C[it] = ND_list(info.Nl[it])
10+
for il in range(info.Nl[it]):
11+
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info.Ne[it], info.Nu[it][il])), dtype=torch.float64, requires_grad=True)
12+
return C
13+
14+
15+
16+
def read_C_init(file_name,info):
17+
""" C[it][il][ie,iu] <jY|\phi> """
18+
C = random_C_init(info)
19+
20+
with open(file_name,"r") as file:
21+
22+
for line in file:
23+
if line.strip() == "<Coefficient>":
24+
line=None
25+
break
26+
ignore_line(file,1)
27+
28+
C_read_index = set()
29+
while True:
30+
line = file.readline().strip()
31+
if line.startswith("Type"):
32+
it,il,iu = list(map(int,file.readline().split()));
33+
it=info.Nt_all[it-1]; iu-=1
34+
C_read_index.add((it,il,iu))
35+
line = file.readline().split()
36+
for ie in range(info.Ne[it]):
37+
if not line: line = file.readline().split()
38+
C[it][il].data[ie,iu] = float(line.pop(0))
39+
elif line.startswith("</Coefficient>"):
40+
break;
41+
else:
42+
raise IOError("unknown line in read_C_init "+file_name+"\n"+line)
43+
return C, C_read_index
44+
45+
46+
47+
def copy_C(C,info):
48+
C_copy = dict()
49+
for it in info.Nt_all:
50+
C_copy[it] = ND_list(info.Nl[it])
51+
for il in range(info.Nl[it]):
52+
C_copy[it][il] = C[it][il].clone()
53+
return C_copy
54+
55+
56+
57+
def write_C(file_name,info,C):
58+
with open(file_name,"w") as file:
59+
print("<Coefficient>", file=file)
60+
print("\tTotal number of radial orbitals.", file=file)
61+
for it,C_t in C.items():
62+
for il,C_tl in enumerate(C_t):
63+
for iu in range(C_tl.size()[1]):
64+
print("\tType\tL\tZeta-Orbital", file=file)
65+
print(f"\t {info.Nt_all.index(it)+1} \t{il}\t {iu+1}", file=file)
66+
for ie in range(C_tl.size()[0]):
67+
print("\t", C_tl[ie,iu].item(), file=file)
68+
print("</Coefficient>", file=file)
69+
70+
71+
#def init_C(info):
72+
# """ C[it][il][ie,iu] """
73+
# C = ND_list(max(info.Nt))
74+
# for it in range(len(C)):
75+
# C[it] = ND_list(info.Nl[it])
76+
# for il in range(info.Nl[it]):
77+
# C[it][il] = torch.autograd.Variable( torch.Tensor( info.Ne, info.Nu[it][il] ), requires_grad = True )
78+
#
79+
# with open("C_init.dat","r") as file:
80+
# line = []
81+
# for it in range(len(C)):
82+
# for il in range(info.Nl[it]):
83+
# for i_n in range(info.Nu[it][il]):
84+
# for ie in range(info.Ne[it]):
85+
# if not line: line=file.readline().split()
86+
# C[it][il].data[ie,i_n] = float(line.pop(0))
87+
# return C
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
def print_V(V,file_name):
2+
""" V[ist][ib] """
3+
with open(file_name,"w") as file:
4+
for V_s in V:
5+
for V_sb in V_s:
6+
print(1-V_sb.item(),end="\t",file=file)
7+
print(file=file)
8+
9+
def print_S(S,file_name):
10+
""" S[ist][it1,it2][il1][il2][ia1*im1*in1,ia2*im2*in2] """
11+
with open(file_name,"w") as file:
12+
for ist,S_s in enumerate(S):
13+
for (it1,it2),S_tt in S_s.items():
14+
for il1,S_ttl in enumerate(S_tt):
15+
for il2,S_ttll in enumerate(S_ttl):
16+
print(ist,it1,it2,il1,il2,file=file)
17+
print(S_ttll.real.numpy(),file=file)
18+
print(S_ttll.imag.numpy(),"\n",file=file)
19+
20+
def print_Q(Q,file_name):
21+
""" Q[ist][it][il][ib,ia*im*iu] """
22+
with open(file_name,"w") as file:
23+
for ist,Q_s in enumerate(Q):
24+
for it,Q_st in Q_s.items():
25+
for il,Q_stl in enumerate(Q_st):
26+
print(ist,it,il,file=file)
27+
print(Q_stl.real.numpy(),file=file)
28+
print(Q_stl.imag.numpy(),"\n",file=file)
29+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
def print_orbital(orb,info):
2+
""" orb[it][il][iu][r] """
3+
for it,orb_t in orb.items():
4+
with open("orb_{0}.dat".format(it),"w") as file:
5+
print_orbital_head(file,info,it)
6+
for il,orb_tl in enumerate(orb_t):
7+
for iu,orb_tlu in enumerate(orb_tl):
8+
print(""" Type L N""",file=file)
9+
print(""" 0 {0} {1}""".format(il,iu),file=file)
10+
for ir,orb_tlur in enumerate(orb_tlu):
11+
print(orb_tlur,end="\t",file=file)
12+
if ir%4==3: print(file=file)
13+
print(file=file)
14+
15+
16+
def plot_orbital(orb,Rcut,dr):
17+
for it,orb_t in orb.items():
18+
with open("orb_{0}_plot.dat".format(it),"w") as file:
19+
Nr = int(Rcut[it]/dr[it])+1
20+
for ir in range(Nr):
21+
print(ir*dr[it],end="\t",file=file)
22+
for il,orb_tl in enumerate(orb_t):
23+
for orb_tlu in orb_tl:
24+
print(orb_tlu[ir],end="\t",file=file)
25+
print(file=file)
26+
27+
28+
def print_orbital_head(file,info,it):
29+
print( "---------------------------------------------------------------------------", file=file )
30+
print( "Element {0}".format(it), file=file )
31+
print( "Energy Cutoff(Ry) {0}".format(info.Ecut[it]), file=file )
32+
print( "Radius Cutoff(a.u.) {0}".format(info.Rcut[it]), file=file )
33+
print( "Lmax {0}".format(info.Nl[it]-1), file=file )
34+
l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1)))
35+
for il,iu in enumerate(info.Nu[it]):
36+
print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file )
37+
print( "---------------------------------------------------------------------------", file=file )
38+
print( "SUMMARY END", file=file )
39+
print( file=file )
40+
print( "Mesh {0}".format(int(info.Rcut[it]/info.dr[it])+1), file=file )
41+
print( "dr {0}".format(info.dr[it]), file=file )
42+
43+

0 commit comments

Comments
 (0)