@@ -136,7 +136,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
136136 B_is_1d = B .ndim == 1
137137
138138 # This will only copy if A is not already fortran contiguous
139- A_f = np .asfortranarray (A )
139+ if A .flags .f_contiguous or (A .flags .c_contiguous and trans in (0 , 1 )):
140+ A_f = A
141+ if A .flags .c_contiguous :
142+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
143+ # Is this valid for complex matrices that were .conj().mT by PyTensor?
144+ lower = not lower
145+ trans = 1 - trans
146+ else :
147+ A_f = np .asfortranarray (A )
140148
141149 if overwrite_b and B .flags .f_contiguous :
142150 B_copy = B
@@ -682,8 +690,11 @@ def impl(
682690 _LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
683691 _solve_check_input_shapes (A , B )
684692
685- if overwrite_a and A .flags .f_contiguous :
693+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
686694 A_copy = A
695+ if A .flags .c_contiguous :
696+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
697+ lower = not lower
687698 else :
688699 A_copy = _copy_to_fortran_order (A )
689700
@@ -905,8 +916,11 @@ def impl(
905916
906917 _N = np .int32 (A .shape [- 1 ])
907918
908- if overwrite_a and A .flags .f_contiguous :
919+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
909920 A_copy = A
921+ if A .flags .c_contiguous :
922+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
923+ lower = not lower
910924 else :
911925 A_copy = _copy_to_fortran_order (A )
912926
@@ -1128,7 +1142,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
11281142 _solve_check_input_shapes (C , B )
11291143
11301144 _N = np .int32 (C .shape [- 1 ])
1131- C_f = np .asfortranarray (C )
1145+ if C .flags .f_contiguous or C .flags .c_contiguous :
1146+ C_f = C
1147+ if C .flags .c_contiguous :
1148+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
1149+ lower = not lower
1150+ else :
1151+ C_f = np .asfortranarray (C )
11321152
11331153 if overwrite_b and B .flags .f_contiguous :
11341154 B_copy = B
0 commit comments