
import numpy as np

+import pytensor.tensor as pt
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
@@ -515,64 +516,43 @@ def perform(self, node, inputs, outputs):

    def L_op(self, inputs, outputs, output_grads):
        """
-        Reverse-mode gradient of the QR function. Adapted from ..[1], which is used in the forward-mode implementation in jax here:
-        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
-
-        And from ..[2] which describes a solution in the square matrix case.
+        Reverse-mode gradient of the QR function.

        References
        ----------
-        .. [1] Townsend, James. "Differentiating the qr decomposition." Online draft, https://j-towns.github.io/papers/qr-derivative.pdf (2018)
-        .. [2] Sebastian F. Walter, Lutz Lehmann & René Lamour. "On evaluating higher-order derivatives
-            of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation",
-            Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
+        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
+        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
        """
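As a reading aid (not part of the patch): the tall-matrix rule from [1] and [2] that the new code below implements is, with M = R dR^H - dQ^H Q, A_bar = (dQ + Q copyltu(M)) R^{-H}, where copyltu(M) keeps the lower triangle of M and mirrors its strictly lower part (conjugated) onto the upper triangle; the wide case (m < n) applies the same rule to the leading m columns of A and propagates the trailing columns through Q.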

        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        m, n = A.shape

        def _H(x: ptb.TensorVariable):
            return x.conj().mT

-        def _copyutl(x: ptb.TensorVariable):
-            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
+        def _copyltu(x: ptb.TensorVariable):
+            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))

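A minimal numpy sketch (outside the patch, for illustration only) of what _copyltu computes: keep the lower triangle, including the diagonal, and mirror the strictly lower part onto the upper triangle as a conjugate transpose.

    import numpy as np

    def copyltu(x):
        # lower triangle (with diagonal) plus the conjugate transpose of the strictly lower part
        return np.tril(x) + np.tril(x, k=-1).conj().T

    M = np.arange(9.0).reshape(3, 3)
    out = copyltu(M)
    assert np.allclose(np.tril(out), np.tril(M))  # lower triangle is preserved
    assert np.allclose(out, out.T)                # symmetric result for real input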
        if self.mode == "raw":
            raise NotImplementedError("Gradient of qr not implemented for mode=raw")

-        elif self.mode == "complete":
-            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
        elif self.mode == "r":
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=r"
-            )
-            A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1]))
            # We need all the components of the QR to compute the gradient of A even if we only
            # use the upper triangular component in the cost function.
            Q, R = qr(A, mode="reduced")
+            dQ = Q.zeros_like()
+            dR = cast(ptb.TensorVariable, output_grads[0])

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
-        if self.mode == "r":
-            dR = cast(ptb.TensorVariable, output_grads[0])
-            R_dRt = R @ _H(dR)
-            M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = Q @ (M_Rinvt + dR)
-            return [A_bar]
+            if self.mode == "complete":
+                qr_assert_op = Assert(
+                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
+                )
+                R = qr_assert_op(R, ptm.le(m, n))

-        else:
            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
@@ -591,13 +571,22 @@ def _copyutl(x: ptb.TensorVariable):

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

-            Qt_dQ = _H(Q) @ dQ
-            R_dRt = R @ _H(dR)
-            M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = M_Rinvt + Q @ dR
-
-            return [A_bar]
+        # gradient expression when m >= n
+        M = R @ _H(dR) - _H(dQ) @ Q
+        K = dQ + Q @ _copyltu(M)
+        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
+
+        # gradient expression when m < n
+        Y = A[:, m:]
+        U = R[:, :m]
+        dU, dV = dR[:, :m], dR[:, m:]
+        dQ_Yt_dV = dQ + Y @ _H(dV)
+        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
+        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
+        Y_bar = Q @ dV
+        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
+
+        return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]


def qr(a, mode="reduced"):
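A minimal sketch of how the new gradient could be exercised end to end (illustrative only, not from the PR; it assumes this module is importable as pytensor.tensor.nlinalg, and names like cost_fn are mine), checking pytensor.grad against central finite differences for both the m >= n and m < n branches:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import qr

    A = pt.dmatrix("A")
    Q, R = qr(A, mode="reduced")
    cost = Q.sum() + (R**2).sum()          # scalar cost touching both outputs
    cost_fn = pytensor.function([A], cost)
    grad_fn = pytensor.function([A], pytensor.grad(cost, A))  # exercises QR's L_op

    rng = np.random.default_rng(0)
    eps = 1e-6
    for shape in [(5, 3), (3, 5)]:         # tall (m >= n) and wide (m < n) inputs
        A_val = rng.normal(size=shape)
        num = np.zeros_like(A_val)
        for idx in np.ndindex(*A_val.shape):
            dA = np.zeros_like(A_val)
            dA[idx] = eps
            num[idx] = (cost_fn(A_val + dA) - cost_fn(A_val - dA)) / (2 * eps)
        assert np.allclose(grad_fn(A_val), num, atol=1e-5)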