1111from pytensor .graph .basic import Apply
1212from pytensor .graph .op import Op
1313from pytensor .npy_2_compat import normalize_axis_tuple
14+ from pytensor .raise_op import Assert
1415from pytensor .tensor import TensorLike
1516from pytensor .tensor import basic as ptb
1617from pytensor .tensor import math as ptm
@@ -530,26 +531,40 @@ def L_op(self, inputs, outputs, output_grads):
530531 from pytensor .tensor .slinalg import solve_triangular
531532
532533 (A ,) = (cast (ptb .TensorVariable , x ) for x in inputs )
533- * _ , m , n = A .type .shape
534534
def _H(x: ptb.TensorVariable):
    """Return the conjugate of ``x`` with its matrix axes transposed.

    Conjugation is applied first and ``.mT`` second, i.e. the Hermitian
    (conjugate) transpose of the trailing matrix dimensions.
    """
    conjugated = x.conj()
    return conjugated.mT
537537
def _copyutl(x: ptb.TensorVariable):
    """Mirror the upper triangle of ``x`` onto its lower triangle.

    The result keeps the upper triangle of ``x`` including the diagonal
    (``k=0``) and adds the conjugate transpose of the strictly upper part
    (``k=1``), so the diagonal is taken from ``x`` exactly once.
    """
    upper_with_diagonal = ptb.triu(x, k=0)
    strictly_upper = ptb.triu(x, k=1)
    # The conjugate transpose of the strict upper part fills the lower triangle.
    return upper_with_diagonal + _H(strictly_upper)
540540
541- if self .mode == "raw" or ( self . mode == "complete" and m != n ) :
542- raise NotImplementedError ("Gradient of qr not implemented" )
541+ if self .mode == "raw" :
542+ raise NotImplementedError ("Gradient of qr not implemented for mode=raw " )
543543
544- elif m < n :
545- raise NotImplementedError (
546- "Gradient of qr not implemented for m x n matrices with m < n"
544+ elif self .mode == "complete" :
545+ Q , R = (cast (ptb .TensorVariable , x ) for x in outputs )
546+ qr_assert_op = Assert (
547+ "Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
547548 )
549+ R = qr_assert_op (R , ptm .eq (R .shape [0 ], R .shape [1 ]))
548550
549551 elif self .mode == "r" :
552+ qr_assert_op = Assert (
553+ "Gradient of qr not implemented for m x n matrices with m < n and mode=r"
554+ )
555+ A = qr_assert_op (A , ptm .ge (A .shape [0 ], A .shape [1 ]))
550556 # We need all the components of the QR to compute the gradient of A even if we only
551557 # use the upper triangular component in the cost function.
552558 Q , R = qr (A , mode = "reduced" )
559+
560+ else :
561+ Q , R = (cast (ptb .TensorVariable , x ) for x in outputs )
562+ qr_assert_op = Assert (
563+ "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
564+ )
565+ R = qr_assert_op (R , ptm .eq (R .shape [0 ], R .shape [1 ]))
566+
567+ if self .mode == "r" :
553568 dR = cast (ptb .TensorVariable , output_grads [0 ])
554569 R_dRt = R @ _H (dR )
555570 M = ptb .tril (R_dRt - _H (R_dRt ), k = - 1 )
@@ -558,8 +573,6 @@ def _copyutl(x: ptb.TensorVariable):
558573 return [A_bar ]
559574
560575 else :
561- Q , R = (cast (ptb .TensorVariable , x ) for x in outputs )
562-
563576 new_output_grads = []
564577 is_disconnected = [
565578 isinstance (x .type , DisconnectedType ) for x in output_grads
0 commit comments