Skip to content

Commit e1a82fc

Browse files
authored
Update preconditioned_stochastic_gradient_descent.py
1. Replace `trtrs` with `triangular_solve` due to PyTorch's API update. 2. Use `torch.chain_matmul` for products of the form A @ B @ C ...
1 parent d33d861 commit e1a82fc

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

preconditioned_stochastic_gradient_descent.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def update_precond_dense(Q, dxs, dgs, step=0.01):
2121
dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])
2222

2323
a = Q.mm(dg)
24-
b = torch.trtrs(dx, Q.t(), upper=False)[0]
24+
b = torch.triangular_solve(dx, Q.t(), upper=False)[0]
2525

2626
grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))
2727
step0 = step/(grad.abs().max() + _tiny)
@@ -91,7 +91,7 @@ def update_precond_kron(Ql, Qr, dX, dG, step=0.01):
9191
Qr = rho*Qr
9292

9393
A = Ql.mm( dG.mm( Qr.t() ) )
94-
Bt = torch.trtrs((torch.trtrs(dX.t(), Qr.t(), upper=False))[0].t(),
94+
Bt = torch.triangular_solve((torch.triangular_solve(dX.t(), Qr.t(), upper=False))[0].t(),
9595
Ql.t(), upper=False)[0]
9696

9797
grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
@@ -110,10 +110,12 @@ def precond_grad_kron(Ql, Qr, Grad):
110110
Qr: (right side) Cholesky factor of preconditioner
111111
Grad: (matrix) gradient
112112
"""
113-
if Grad.shape[0] > Grad.shape[1]:
114-
return Ql.t().mm( Ql.mm( Grad.mm( Qr.t().mm(Qr) ) ) )
115-
else:
116-
return (((Ql.t().mm(Ql)).mm(Grad)).mm(Qr.t())).mm(Qr)
113+
#if Grad.shape[0] > Grad.shape[1]:
114+
# return Ql.t().mm( Ql.mm( Grad.mm( Qr.t().mm(Qr) ) ) )
115+
#else:
116+
# return (((Ql.t().mm(Ql)).mm(Grad)).mm(Qr.t())).mm(Qr)
117+
# replace it with chain matrix multiplication by lixilinx on Dec. 5, 2020
118+
return torch.chain_matmul(Ql.t(), Ql, Grad, Qr.t(), Qr)
117119

118120

119121

@@ -201,7 +203,7 @@ def update_precond_scaw(Ql, qr, dX, dG, step=0.01):
201203
qr = rho*qr
202204

203205
A = Ql.mm( dG*qr )
204-
Bt = torch.trtrs(dX/qr, Ql.t(), upper=False)[0]
206+
Bt = torch.triangular_solve(dX/qr, Ql.t(), upper=False)[0]
205207

206208
grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
207209
grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)
@@ -216,7 +218,9 @@ def precond_grad_scaw(Ql, qr, Grad):
216218
"""
217219
apply scaling-and-whitening preconditioner
218220
"""
219-
return (Ql.t().mm(Ql)).mm(Grad*(qr*qr))
221+
#return (Ql.t().mm(Ql)).mm(Grad*(qr*qr))
222+
# replaced with chain matrix product on Dec. 5, 2020 by lixilinx
223+
return torch.chain_matmul(Ql.t(), Ql, Grad*(qr*qr))
220224

221225

222226

@@ -259,23 +263,23 @@ def update_precond_splu(L12, l3, U12, u3, dxs, dgs, step=0.01):
259263
Qg1 = L1.mm(Ug1)
260264
Qg2 = L2.mm(Ug1) + l3*Ug2
261265
# inv(U^T)*dx
262-
iUtx1 = torch.trtrs(dx[:r], U1.t(), upper=False)[0]
266+
iUtx1 = torch.triangular_solve(dx[:r], U1.t(), upper=False)[0]
263267
iUtx2 = (dx[r:] - U2.t().mm(iUtx1))/u3
264268
# inv(Q^T)*dx
265269
iQtx2 = iUtx2/l3
266-
iQtx1 = torch.trtrs(iUtx1 - L2.t().mm(iQtx2), L1.t(), upper=True)[0]
270+
iQtx1 = torch.triangular_solve(iUtx1 - L2.t().mm(iQtx2), L1.t(), upper=True)[0]
267271
# L^T*Q*dg
268272
LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)
269273
LtQg2 = l3*Qg2
270274
# P*dg
271275
Pg1 = U1.t().mm(LtQg1)
272276
Pg2 = U2.t().mm(LtQg1) + u3*LtQg2
273277
# inv(L)*inv(Q^T)*dx
274-
iLiQtx1 = torch.trtrs(iQtx1, L1, upper=False)[0]
278+
iLiQtx1 = torch.triangular_solve(iQtx1, L1, upper=False)[0]
275279
iLiQtx2 = (iQtx2 - L2.mm(iLiQtx1))/l3
276280
# inv(P)*dx
277281
iPx2 = iLiQtx2/u3
278-
iPx1 = torch.trtrs(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]
282+
iPx1 = torch.triangular_solve(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]
279283

280284
# update L
281285
grad1 = Qg1.mm(Qg1.t()) - iQtx1.mm(iQtx1.t())

0 commit comments

Comments (0)