Skip to content

Commit efd30d4

Browse files
committed
1. In opt_orb_pytorch_dpsi, change type of tensors from class ComplexTensor to torch.tensor(complex128)
1 parent 290c9b8 commit efd30d4

File tree

3 files changed

+26
-36
lines changed

3 files changed

+26
-36
lines changed

tools/opt_orb_pytorch_dpsi/IO/read_QSV.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from util import *
22
import torch
3-
import torch_complex
43
import itertools
54
import numpy as np
65
import re
@@ -92,22 +91,17 @@ def read_QI(info,ist,data):
9291
for it in info.Nt[ist]:
9392
QI[it] = ND_list(info.Nl[it])
9493
for il in range(info.Nl[it]):
95-
QI[it][il] = torch_complex.ComplexTensor(
96-
np.empty((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]),dtype=np.float64),
97-
np.empty((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]),dtype=np.float64) )
94+
QI[it][il] = torch.zeros((info.Nb[ist],info.Na[ist][it],info.Nm(il),info.Ne[it]), dtype=torch.complex128)
9895
for ib in range(info.Nb[ist]):
9996
for it in info.Nt[ist]:
10097
for ia in range(info.Na[ist][it]):
10198
for il in range(info.Nl[it]):
10299
for im in range(info.Nm(il)):
103100
for ie in range(info.Ne[it]):
104-
QI[it][il].real[ib,ia,im,ie] = next(data)
105-
QI[it][il].imag[ib,ia,im,ie] = next(data)
101+
QI[it][il][ib,ia,im,ie] = complex(next(data), next(data))
106102
for it in info.Nt[ist]:
107103
for il in range(info.Nl[it]):
108-
QI[it][il] = torch_complex.ComplexTensor(
109-
torch.from_numpy(QI[it][il].real).view(-1,info.Ne[it]),
110-
torch.from_numpy(QI[it][il].imag).view(-1,info.Ne[it])).conj()
104+
QI[it][il] = QI[it][il].view(-1,info.Ne[it]).conj()
111105
return QI
112106

113107

@@ -118,9 +112,7 @@ def read_SI(info,ist,data):
118112
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
119113
SI[it1,it2] = ND_list(info.Nl[it1],info.Nl[it2])
120114
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
121-
SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
122-
np.empty((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]),dtype=np.float64),
123-
np.empty((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]),dtype=np.float64) )
115+
SI[it1,it2][il1][il2] = torch.zeros((info.Na[ist][it1],info.Nm(il1),info.Ne[it1],info.Na[ist][it2],info.Nm(il2),info.Ne[it2]), dtype=torch.complex128)
124116
for it1 in info.Nt[ist]:
125117
for ia1 in range(info.Na[ist][it1]):
126118
for il1 in range(info.Nl[it1]):
@@ -131,13 +123,12 @@ def read_SI(info,ist,data):
131123
for im2 in range(info.Nm(il2)):
132124
for ie1 in range(info.Ne[it1]):
133125
for ie2 in range(info.Ne[it2]):
134-
SI[it1,it2][il1][il2].real[ia1,im1,ie1,ia2,im2,ie2] = next(data)
135-
SI[it1,it2][il1][il2].imag[ia1,im1,ie1,ia2,im2,ie2] = next(data)
136-
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
137-
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
138-
SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
139-
torch.from_numpy(SI[it1,it2][il1][il2].real),
140-
torch.from_numpy(SI[it1,it2][il1][il2].imag))
126+
SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] = complex(next(data), next(data))
127+
# for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
128+
# for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
129+
# SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
130+
# torch.from_numpy(SI[it1,it2][il1][il2].real),
131+
# torch.from_numpy(SI[it1,it2][il1][il2].imag))
141132
return SI
142133

143134

tools/opt_orb_pytorch_dpsi/opt_orbital.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from util import ND_list
22
import util
3-
import torch_complex
43
import functools
54
import itertools
65
import torch
@@ -19,7 +18,7 @@ def cal_Q(self,QI,C,info,ist):
1918

2019
for it in info.Nt[ist]:
2120
for il in range(info.Nl[it]):
22-
Q[it][il] = torch_complex.mm( QI[it][il], C[it][il] ).view(info.Nb[ist],-1)
21+
Q[it][il] = torch.mm( QI[it][il], C[it][il].to(torch.complex128) ).view(info.Nb[ist],-1)
2322
return Q
2423

2524

@@ -37,18 +36,18 @@ def cal_S(self,SI,C,info,ist):
3736
for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
3837
for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
3938
# SI_C[ia1*im1*ie1*ia2*im2,iu2]
40-
SI_C = torch_complex.mm(
39+
SI_C = torch.mm(
4140
SI[it1,it2][il1][il2].view(-1,info.Ne[it2]),
42-
C[it2][il2] )
41+
C[it2][il2].to(torch.complex128) )
4342
# SI_C[ia1*im1,ie1,ia2*im2*iu2]
4443
SI_C = SI_C.view( info.Na[ist][it1]*info.Nm(il1), info.Ne[it1], -1 )
4544
# Ct[iu1,ie1]
46-
Ct = C[it1][il1].t()
47-
C_mm = functools.partial(torch_complex.mm,Ct)
45+
Ct = C[it1][il1].t().to(torch.complex128)
46+
C_mm = functools.partial(torch.mm,Ct)
4847
# C_SI_C[ia1*im1][iu1,ia2*im2*iu2]
4948
C_SI_C = list(map( C_mm, SI_C ))
5049
# C_SI_C[ia1*im1*iu1,ia2*im2*iu2]
51-
C_SI_C = torch_complex.cat( C_SI_C, dim=0 )
50+
C_SI_C = torch.cat( C_SI_C, dim=0 )
5251
#??? C_SI_C = C_SI_C.view(info.Na[ist][it1]*info.Nm(il1)*info.Nu[it1][il1],-1)
5352
S[it1,it2][il1][il2] = C_SI_C
5453
return S
@@ -69,11 +68,11 @@ def change_index_S(self,S,info,ist): # S[it1,it2][il1][il2][ia1*im1*iu1,ia
6968
# S_tt[il1][ia1*im1*iu1,il2*ia2*im2*iu2]
7069
S_tt = ND_list(info.Nl[it1])
7170
for il1 in range(info.Nl[it1]):
72-
S_tt[il1] = torch_complex.cat( S[it1,it2][il1], dim=1 )
73-
S_t[it2] = torch_complex.cat( S_tt, dim=0 )
74-
S_[it1] = torch_complex.cat( list(S_t.values()), dim=1 )
71+
S_tt[il1] = torch.cat( S[it1,it2][il1], dim=1 )
72+
S_t[it2] = torch.cat( S_tt, dim=0 )
73+
S_[it1] = torch.cat( list(S_t.values()), dim=1 )
7574
# S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
76-
S_cat = torch_complex.cat( list(S_.values()), dim=0 )
75+
S_cat = torch.cat( list(S_.values()), dim=0 )
7776
return S_cat
7877

7978

@@ -91,10 +90,10 @@ def change_index_Q(self,Q,info,ist): # Q[it][il][ib,ia*im*iu]
9190
for it in info.Nt[ist]:
9291
# Q_ts[il][ia*im*iu]
9392
Q_ts = [ Q_tl[ib] for Q_tl in Q[it] ]
94-
Q_[it] = torch_complex.cat(Q_ts)
95-
Q_b[ib] = torch_complex.cat(list(Q_.values())).view(1,-1)
93+
Q_[it] = torch.cat(Q_ts)
94+
Q_b[ib] = torch.cat(list(Q_.values())).view(1,-1)
9695
# Q_cat[ib,it*il*ia*im*iu]
97-
Q_cat = torch_complex.cat( Q_b, dim=0 )
96+
Q_cat = torch.cat( Q_b, dim=0 )
9897
return Q_cat
9998

10099

@@ -107,8 +106,8 @@ def cal_coef(self,Q,S):
107106
coef[ib,it*il*ia*im*iu]
108107
= Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1}
109108
"""
110-
S_I = torch_complex.inverse(S)
111-
coef = torch_complex.mm(Q, S_I)
109+
S_I = torch.inverse(S)
110+
coef = torch.mm(Q, S_I)
112111
return coef
113112

114113

@@ -122,7 +121,7 @@ def cal_V(self,coef,Q):
122121
= sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
123122
Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
124123
"""
125-
V = torch_complex.mm( coef, Q.t().conj() ).real
124+
V = torch.mm( coef, Q.t().conj() ).real
126125
return V
127126

128127

0 commit comments

Comments
 (0)