@@ -512,6 +512,78 @@ def perform(self, node, inputs, outputs):
         else:
             outputs[0][0] = res

+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Reverse-mode gradient of the QR function. Adapted from [1]_, which underlies the
+        forward-mode implementation in JAX here:
+        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
+
+        And from [2]_, which describes a solution for the square-matrix case.
+
+        References
+        ----------
+        .. [1] Townsend, James. "Differentiating the QR decomposition." Online draft,
+           https://j-towns.github.io/papers/qr-derivative.pdf (2018).
+        .. [2] Walter, Sebastian F., Lutz Lehmann, and René Lamour. "On evaluating
+           higher-order derivatives of the QR decomposition of tall matrices with full
+           column rank in forward and reverse mode algorithmic differentiation."
+           Optimization Methods and Software 27.2 (2012): 391-403.
+           DOI: 10.1080/10556788.2011.610454.
+        """
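+
+        # Sketch of the reverse-mode rule from [1]_: for A = Q @ R, the input
+        # cotangent is A_bar = (dQ + Q @ copyltu(M)) @ inv(R)^H, where
+        # M = R @ dR^H - dQ^H @ Q and copyltu copies the strictly lower triangle
+        # onto the upper triangle (conjugate-transposed). The branches below
+        # compute rearranged forms of this expression.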
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        *_, m, n = A.type.shape
+
+        def _H(x: ptb.TensorVariable):
+            return x.conj().T
+
+        def _copyutl(x: ptb.TensorVariable):
+            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
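+        # _copyutl mirrors the strict upper triangle of its argument into the lower
+        # triangle (conjugate-transposed), keeping the diagonal: the upper-triangular
+        # counterpart of the "copyltu" helper in [1]_.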
+
+        if self.mode == "raw" or (self.mode == "complete" and m != n):
+            raise NotImplementedError("Gradient of qr not implemented")
+
+        elif m < n:
+            raise NotImplementedError(
+                "Gradient of qr not implemented for m x n matrices with m < n"
+            )
+
+        elif self.mode == "r":
+            # We need all the components of the QR to compute the gradient of A, even
+            # if we only use the upper triangular component in the cost function.
+            Q, R = qr(A, mode="reduced")
+            dR = cast(ptb.TensorVariable, output_grads[0])
+            R_dRt = R @ _H(dR)
+            Rinvt = _H(inv(R))
+            A_bar = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) @ Rinvt + dR)
+            return [A_bar]
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never be reached by PyTensor
+                return [DisconnectedType()()]  # pragma: no cover
+
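+            # Replace disconnected output gradients with zeros of the right shape so
+            # that dQ and dR are always defined in the gradient expression below.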
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [Q, R], strict=True
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            Rinvt = _H(inv(R))
+            Qt_dQ = _H(Q) @ dQ
+            R_dRt = R @ _H(dR)
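+            # Equivalent to (dQ + Q @ copyltu(M)) @ inv(R)^H from [1]_, with the
+            # dR-dependent part of copyltu(M) simplified into the trailing Q @ dR term.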
+            A_bar = (
+                Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
+            ) @ Rinvt + Q @ dR
+
+            return [A_bar]
+

 def qr(a, mode="reduced"):
     """
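A quick way to sanity-check the new gradient is to compare it against finite differences with `pytensor.gradient.verify_grad`. The snippet below is only a sketch, not part of the commit; it assumes a PyTensor build that includes this `L_op` and the usual `qr` entry point in `pytensor.tensor.nlinalg`:

```python
import numpy as np

from pytensor.gradient import verify_grad
from pytensor.tensor.nlinalg import qr

rng = np.random.default_rng(42)
A_val = rng.normal(size=(5, 3))  # tall matrix, m >= n

# A scalar cost that touches both Q and R exercises the `else` branch above.
verify_grad(lambda A: sum(x.sum() for x in qr(A, mode="reduced")), [A_val], rng=rng)

# A cost built from R alone goes through the `mode == "r"` branch instead.
verify_grad(lambda A: qr(A, mode="r").sum(), [A_val], rng=rng)
```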