@@ -21,7 +21,7 @@ def update_precond_dense(Q, dxs, dgs, step=0.01):
     dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])

     a = Q.mm(dg)
-    b = torch.trtrs(dx, Q.t(), upper=False)[0]
+    b = torch.triangular_solve(dx, Q.t(), upper=False)[0]

     grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))
     step0 = step/(grad.abs().max() + _tiny)
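Note: torch.trtrs was deprecated and renamed in PyTorch 1.1; torch.triangular_solve takes the same (b, A, upper=...) arguments and likewise returns a (solution, cloned_coefficient) pair, so the trailing [0] still selects the solution. A minimal sanity check of the replacement (a sketch; the tensors here are illustrative, not from the source):

import torch

Q = torch.triu(torch.rand(5, 5)) + 5*torch.eye(5)   # toy upper-triangular factor, well conditioned
dx = torch.rand(5, 1)

# new API: solve Q.t() @ b = dx, where Q.t() is lower triangular (upper=False)
b = torch.triangular_solve(dx, Q.t(), upper=False)[0]
print(torch.allclose(Q.t().mm(b), dx, atol=1e-6))   # True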
@@ -91,7 +91,7 @@ def update_precond_kron(Ql, Qr, dX, dG, step=0.01):
     Qr = rho*Qr

     A = Ql.mm( dG.mm( Qr.t() ) )
-    Bt = torch.trtrs((torch.trtrs(dX.t(), Qr.t(), upper=False))[0].t(),
+    Bt = torch.triangular_solve((torch.triangular_solve(dX.t(), Qr.t(), upper=False))[0].t(),
                 Ql.t(), upper=False)[0]

     grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
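The nested solves compute Bt = inv(Ql^T) * dX * inv(Qr) without forming either inverse explicitly. A quick check of that identity (sketch under the same PyTorch >= 1.1 assumption; shapes are illustrative):

import torch

Ql = torch.triu(torch.rand(4, 4)) + 4*torch.eye(4)   # toy left Cholesky factor
Qr = torch.triu(torch.rand(3, 3)) + 3*torch.eye(3)   # toy right Cholesky factor
dX = torch.rand(4, 3)

inner = torch.triangular_solve(dX.t(), Qr.t(), upper=False)[0]   # inv(Qr^T) @ dX^T
Bt = torch.triangular_solve(inner.t(), Ql.t(), upper=False)[0]   # inv(Ql^T) @ dX @ inv(Qr)
print(torch.allclose(Ql.t().mm(Bt).mm(Qr), dX, atol=1e-5))       # True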
@@ -110,10 +110,12 @@ def precond_grad_kron(Ql, Qr, Grad):
     Qr: (right side) Cholesky factor of preconditioner
     Grad: (matrix) gradient
     """
-    if Grad.shape[0] > Grad.shape[1]:
-        return Ql.t().mm( Ql.mm( Grad.mm( Qr.t().mm(Qr) ) ) )
-    else:
-        return (((Ql.t().mm(Ql)).mm(Grad)).mm(Qr.t())).mm(Qr)
+    #if Grad.shape[0] > Grad.shape[1]:
+    #    return Ql.t().mm( Ql.mm( Grad.mm( Qr.t().mm(Qr) ) ) )
+    #else:
+    #    return (((Ql.t().mm(Ql)).mm(Grad)).mm(Qr.t())).mm(Qr)
+    # replace it with chain matrix multiplication by lixilinx on Dec. 5, 2020
+    return torch.chain_matmul(Ql.t(), Ql, Grad, Qr.t(), Qr)



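The deleted if/else picked the cheaper multiplication order by hand from Grad's shape; torch.chain_matmul (in PyTorch since roughly 1.0, later superseded by torch.linalg.multi_dot) finds the cost-optimal parenthesization of the whole chain automatically, so one line covers both branches. An equivalence check (sketch; shapes illustrative):

import torch

Ql = torch.triu(torch.rand(6, 6)) + torch.eye(6)
Qr = torch.triu(torch.rand(2, 2)) + torch.eye(2)
Grad = torch.rand(6, 2)                                # tall case: Grad.shape[0] > Grad.shape[1]

old = Ql.t().mm(Ql.mm(Grad.mm(Qr.t().mm(Qr))))         # hand-picked order from the removed branch
new = torch.chain_matmul(Ql.t(), Ql, Grad, Qr.t(), Qr) # order chosen automatically
print(torch.allclose(old, new, atol=1e-5))             # True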
@@ -201,7 +203,7 @@ def update_precond_scaw(Ql, qr, dX, dG, step=0.01):
     qr = rho*qr

     A = Ql.mm( dG*qr )
-    Bt = torch.trtrs(dX/qr, Ql.t(), upper=False)[0]
+    Bt = torch.triangular_solve(dX/qr, Ql.t(), upper=False)[0]

     grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
     grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)
@@ -216,7 +218,9 @@ def precond_grad_scaw(Ql, qr, Grad):
     """
     apply scaling-and-whitening preconditioner
     """
-    return (Ql.t().mm(Ql)).mm(Grad*(qr*qr))
+    #return (Ql.t().mm(Ql)).mm(Grad*(qr*qr))
+    # replace with chain product on Dec, 5, 2020 by lixilinx
+    return torch.chain_matmul(Ql.t(), Ql, Grad*(qr*qr))



@@ -259,23 +263,23 @@ def update_precond_splu(L12, l3, U12, u3, dxs, dgs, step=0.01):
     Qg1 = L1.mm(Ug1)
     Qg2 = L2.mm(Ug1) + l3*Ug2
     # inv(U^T)*dx
-    iUtx1 = torch.trtrs(dx[:r], U1.t(), upper=False)[0]
+    iUtx1 = torch.triangular_solve(dx[:r], U1.t(), upper=False)[0]
     iUtx2 = (dx[r:] - U2.t().mm(iUtx1))/u3
     # inv(Q^T)*dx
     iQtx2 = iUtx2/l3
-    iQtx1 = torch.trtrs(iUtx1 - L2.t().mm(iQtx2), L1.t(), upper=True)[0]
+    iQtx1 = torch.triangular_solve(iUtx1 - L2.t().mm(iQtx2), L1.t(), upper=True)[0]
     # L^T*Q*dg
     LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)
     LtQg2 = l3*Qg2
     # P*dg
     Pg1 = U1.t().mm(LtQg1)
     Pg2 = U2.t().mm(LtQg1) + u3*LtQg2
     # inv(L)*inv(Q^T)*dx
-    iLiQtx1 = torch.trtrs(iQtx1, L1, upper=False)[0]
+    iLiQtx1 = torch.triangular_solve(iQtx1, L1, upper=False)[0]
     iLiQtx2 = (iQtx2 - L2.mm(iLiQtx1))/l3
     # inv(P)*dx
     iPx2 = iLiQtx2/u3
-    iPx1 = torch.trtrs(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]
+    iPx1 = torch.triangular_solve(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]

     # update L
     grad1 = Qg1.mm(Qg1.t()) - iQtx1.mm(iQtx1.t())
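All four solves in this hunk exploit the block structure of the sparse LU factors, L = [[L1, 0], [L2, diag(l3)]] and U = [[U1, U2], [0, diag(u3)]]: only the dense r-by-r corner needs a triangular_solve, and the diagonal tail is finished elementwise. A standalone sketch of the first step, inv(U^T)*dx, with u3 reduced to a scalar for brevity (in the source it is a diagonal; all names and shapes here are illustrative):

import torch

r, n = 3, 8                                          # dense corner size and total dimension
U1 = torch.triu(torch.rand(r, r)) + 3*torch.eye(r)   # dense upper-triangular corner
U2 = torch.rand(r, n - r)                            # off-diagonal block
u3 = 2.0                                             # diagonal tail, scalar here for simplicity
dx = torch.rand(n, 1)

# forward substitution through U^T = [[U1^T, 0], [U2^T, u3*I]]
iUtx1 = torch.triangular_solve(dx[:r], U1.t(), upper=False)[0]
iUtx2 = (dx[r:] - U2.t().mm(iUtx1))/u3

# verify against the explicitly assembled U
U = torch.cat([torch.cat([U1, U2], dim=1),
               torch.cat([torch.zeros(n - r, r), u3*torch.eye(n - r)], dim=1)])
y = torch.cat([iUtx1, iUtx2])
print(torch.allclose(U.t().mm(y), dx, atol=1e-5))    # True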