11from util import ND_list
22import util
3- import torch_complex
43import functools
54import itertools
65import torch
@@ -19,7 +18,7 @@ def cal_Q(self,QI,C,info,ist):
1918
2019 for it in info .Nt [ist ]:
2120 for il in range (info .Nl [it ]):
22- Q [it ][il ] = torch_complex .mm ( QI [it ][il ], C [it ][il ] ).view (info .Nb [ist ],- 1 )
21+ Q [it ][il ] = torch .mm ( QI [it ][il ], C [it ][il ]. to ( torch . complex128 ) ).view (info .Nb [ist ],- 1 )
2322 return Q
2423
2524
@@ -37,18 +36,18 @@ def cal_S(self,SI,C,info,ist):
3736 for it1 ,it2 in itertools .product ( info .Nt [ist ], info .Nt [ist ] ):
3837 for il1 ,il2 in itertools .product ( range (info .Nl [it1 ]), range (info .Nl [it2 ]) ):
3938 # SI_C[ia1*im1*ie1*ia2*im2,iu2]
40- SI_C = torch_complex .mm (
39+ SI_C = torch .mm (
4140 SI [it1 ,it2 ][il1 ][il2 ].view (- 1 ,info .Ne [it2 ]),
42- C [it2 ][il2 ] )
41+ C [it2 ][il2 ]. to ( torch . complex128 ) )
4342 # SI_C[ia1*im1,ie1,ia2*im2*iu2]
4443 SI_C = SI_C .view ( info .Na [ist ][it1 ]* info .Nm (il1 ), info .Ne [it1 ], - 1 )
4544 # Ct[iu1,ie1]
46- Ct = C [it1 ][il1 ].t ()
47- C_mm = functools .partial (torch_complex .mm ,Ct )
45+ Ct = C [it1 ][il1 ].t (). to ( torch . complex128 )
46+ C_mm = functools .partial (torch .mm ,Ct )
4847 # C_SI_C[ia1*im1][iu1,ia2*im2*iu2]
4948 C_SI_C = list (map ( C_mm , SI_C ))
5049 # C_SI_C[ia1*im1*iu1,ia2*im2*iu2]
51- C_SI_C = torch_complex .cat ( C_SI_C , dim = 0 )
50+ C_SI_C = torch .cat ( C_SI_C , dim = 0 )
5251#??? C_SI_C = C_SI_C.view(info.Na[ist][it1]*info.Nm(il1)*info.Nu[it1][il1],-1)
5352 S [it1 ,it2 ][il1 ][il2 ] = C_SI_C
5453 return S
@@ -69,11 +68,11 @@ def change_index_S(self,S,info,ist): # S[it1,it2][il1][il2][ia1*im1*iu1,ia
6968 # S_tt[il1][ia1*im1*iu1,il2*ia2*im2*iu2]
7069 S_tt = ND_list (info .Nl [it1 ])
7170 for il1 in range (info .Nl [it1 ]):
72- S_tt [il1 ] = torch_complex .cat ( S [it1 ,it2 ][il1 ], dim = 1 )
73- S_t [it2 ] = torch_complex .cat ( S_tt , dim = 0 )
74- S_ [it1 ] = torch_complex .cat ( list (S_t .values ()), dim = 1 )
71+ S_tt [il1 ] = torch .cat ( S [it1 ,it2 ][il1 ], dim = 1 )
72+ S_t [it2 ] = torch .cat ( S_tt , dim = 0 )
73+ S_ [it1 ] = torch .cat ( list (S_t .values ()), dim = 1 )
7574 # S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
76- S_cat = torch_complex .cat ( list (S_ .values ()), dim = 0 )
75+ S_cat = torch .cat ( list (S_ .values ()), dim = 0 )
7776 return S_cat
7877
7978
@@ -91,10 +90,10 @@ def change_index_Q(self,Q,info,ist): # Q[it][il][ib,ia*im*iu]
9190 for it in info .Nt [ist ]:
9291 # Q_ts[il][ia*im*iu]
9392 Q_ts = [ Q_tl [ib ] for Q_tl in Q [it ] ]
94- Q_ [it ] = torch_complex .cat (Q_ts )
95- Q_b [ib ] = torch_complex .cat (list (Q_ .values ())).view (1 ,- 1 )
93+ Q_ [it ] = torch .cat (Q_ts )
94+ Q_b [ib ] = torch .cat (list (Q_ .values ())).view (1 ,- 1 )
9695 # Q_cat[ib,it*il*ia*im*iu]
97- Q_cat = torch_complex .cat ( Q_b , dim = 0 )
96+ Q_cat = torch .cat ( Q_b , dim = 0 )
9897 return Q_cat
9998
10099
@@ -107,8 +106,8 @@ def cal_coef(self,Q,S):
107106 coef[ib,it*il*ia*im*iu]
108107 = Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1}
109108 """
110- S_I = torch_complex .inverse (S )
111- coef = torch_complex .mm (Q , S_I )
109+ S_I = torch .inverse (S )
110+ coef = torch .mm (Q , S_I )
112111 return coef
113112
114113
@@ -122,7 +121,7 @@ def cal_V(self,coef,Q):
122121 = sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
123122 Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
124123 """
125- V = torch_complex .mm ( coef , Q .t ().conj () ).real
124+ V = torch .mm ( coef , Q .t ().conj () ).real
126125 return V
127126
128127
0 commit comments