Skip to content

Commit ac48c11

Browse files
author
Etienne Duchesne
committed
qr decompostion gradient: replace swtich by ifelse
1 parent 9e5e765 commit ac48c11

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.gradient import DisconnectedType
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.op import Op
14+
from pytensor.ifelse import ifelse
1415
from pytensor.npy_2_compat import normalize_axis_tuple
1516
from pytensor.raise_op import Assert
1617
from pytensor.tensor import TensorLike
@@ -586,7 +587,7 @@ def _copyltu(x: ptb.TensorVariable):
586587
Y_bar = Q @ dV
587588
A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
588589

589-
return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
590+
return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
590591

591592

592593
def qr(a, mode="reduced"):

0 commit comments

Comments
 (0)