@@ -126,6 +126,9 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
126
126
dtype = A .dtype
127
127
w_type = _get_underlying_float (dtype )
128
128
numba_trtrs = _LAPACK ().numba_xtrtrs (dtype )
129
+ if isinstance (dtype , types .Complex ):
130
+ # If you want to make this work with complex numbers make sure you handle the c_contiguous trick correctly
131
+ raise TypeError ("This function is not expected to work with complex numbers" )
129
132
130
133
def impl (A , B , trans , lower , unit_diagonal , b_ndim , overwrite_b ):
131
134
_N = np .int32 (A .shape [- 1 ])
@@ -135,8 +138,15 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
135
138
# could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
136
139
B_is_1d = B .ndim == 1
137
140
138
- # This will only copy if A is not already fortran contiguous
139
- A_f = np .asfortranarray (A )
141
+ if A .flags .f_contiguous or (A .flags .c_contiguous and trans in (0 , 1 )):
142
+ A_f = A
143
+ if A .flags .c_contiguous :
144
+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
145
+ # Is this valid for complex matrices that were .conj().mT by PyTensor?
146
+ lower = not lower
147
+ trans = 1 - trans
148
+ else :
149
+ A_f = np .asfortranarray (A )
140
150
141
151
if overwrite_b and B .flags .f_contiguous :
142
152
B_copy = B
@@ -633,6 +643,11 @@ def impl(
633
643
_N = np .int32 (A .shape [- 1 ])
634
644
_solve_check_input_shapes (A , B )
635
645
646
+ if overwrite_a and A .flags .c_contiguous :
647
+ # Work with the transposed system to avoid copying A
648
+ A = A .T
649
+ transposed = not transposed
650
+
636
651
order = "I" if transposed else "1"
637
652
norm = _xlange (A , order = order )
638
653
@@ -682,8 +697,11 @@ def impl(
682
697
_LDA , _N = np .int32 (A .shape [- 2 :]) # type: ignore
683
698
_solve_check_input_shapes (A , B )
684
699
685
- if overwrite_a and A .flags .f_contiguous :
700
+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
686
701
A_copy = A
702
+ if A .flags .c_contiguous :
703
+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
704
+ lower = not lower
687
705
else :
688
706
A_copy = _copy_to_fortran_order (A )
689
707
@@ -905,8 +923,11 @@ def impl(
905
923
906
924
_N = np .int32 (A .shape [- 1 ])
907
925
908
- if overwrite_a and A .flags .f_contiguous :
926
+ if overwrite_a and ( A .flags .f_contiguous or A . flags . c_contiguous ) :
909
927
A_copy = A
928
+ if A .flags .c_contiguous :
929
+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
930
+ lower = not lower
910
931
else :
911
932
A_copy = _copy_to_fortran_order (A )
912
933
@@ -1128,7 +1149,13 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
1128
1149
_solve_check_input_shapes (C , B )
1129
1150
1130
1151
_N = np .int32 (C .shape [- 1 ])
1131
- C_f = np .asfortranarray (C )
1152
+ if C .flags .f_contiguous or C .flags .c_contiguous :
1153
+ C_f = C
1154
+ if C .flags .c_contiguous :
1155
+ # An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
1156
+ lower = not lower
1157
+ else :
1158
+ C_f = np .asfortranarray (C )
1132
1159
1133
1160
if overwrite_b and B .flags .f_contiguous :
1134
1161
B_copy = B
0 commit comments