@@ -126,6 +126,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
126126 dtype = A .dtype
127127 w_type = _get_underlying_float (dtype )
128128 numba_trtrs = _LAPACK ().numba_xtrtrs (dtype )
129+ if isinstance (dtype , types .Complex ):
130+ # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
131+ raise TypeError ("This function is not expected to work with complex numbers" )
129132
130133 def impl (A , B , trans , lower , unit_diagonal , b_ndim , overwrite_b ):
131134 _N = np .int32 (A .shape [- 1 ])
@@ -135,8 +138,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
135138 # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
136139 B_is_1d = B .ndim == 1
137140
138- # This will only copy if A is not already fortran contiguous
139- A_f = np .asfortranarray (A )
141+ if A .flags .f_contiguous or (A .flags .c_contiguous and trans in (0 , 1 )):
142+ A_f = A
143+ if A .flags .c_contiguous :
144+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
145+ # Is this valid for complex matrices that were .conj().mT by PyTensor?
146+ lower = not lower
147+ trans = 1 - trans
148+ else :
149+ A_f = np .asfortranarray (A )
140150
141151 if overwrite_b and B .flags .f_contiguous :
142152 B_copy = B
@@ -633,6 +643,11 @@ def impl(
633643 _N = np .int32 (A .shape [- 1 ])
634644 _solve_check_input_shapes (A , B )
635645
646+ if overwrite_a and A .flags .c_contiguous :
647+ # Work with the transposed system to avoid copying A
648+ A = A .T
649+ transposed = not transposed
650+
636651 order = "I" if transposed else "1"
637652 norm = _xlange (A , order = order )
638653
@@ -682,8 +697,11 @@ def impl(
682697 _LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
683698 _solve_check_input_shapes (A , B )
684699
685- if overwrite_a and A .flags .f_contiguous :
700+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
686701 A_copy = A
702+ if A .flags .c_contiguous :
703+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
704+ lower = not lower
687705 else :
688706 A_copy = _copy_to_fortran_order (A )
689707
@@ -905,8 +923,11 @@ def impl(
905923
906924 _N = np .int32 (A .shape [- 1 ])
907925
908- if overwrite_a and A .flags .f_contiguous :
926+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
909927 A_copy = A
928+ if A .flags .c_contiguous :
929+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
930+ lower = not lower
910931 else :
911932 A_copy = _copy_to_fortran_order (A )
912933
@@ -1128,7 +1149,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11281149 _solve_check_input_shapes (C , B )
11291150
11301151 _N = np .int32 (C .shape [- 1 ])
1131- C_f = np .asfortranarray (C )
1152+ if C .flags .f_contiguous or C .flags .c_contiguous :
1153+ C_f = C
1154+ if C .flags .c_contiguous :
1155+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
1156+ lower = not lower
1157+ else :
1158+ C_f = np .asfortranarray (C )
11321159
11331160 if overwrite_b and B .flags .f_contiguous :
11341161 B_copy = B
0 commit comments