|
24 | 24 | Solve, |
25 | 25 | SolveTriangular, |
26 | 26 | ) |
| 27 | +from pytensor.tensor.type import complex_dtypes |
| 28 | + |
| 29 | + |
| 30 | +_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( |
| 31 | + "Complex dtype for {op} not supported in numba mode. " |
| 32 | + "If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor" |
| 33 | +) |
27 | 34 |
|
28 | 35 |
|
29 | 36 | @numba_basic.numba_njit(inline="always") |
@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): |
199 | 206 | b_ndim = op.b_ndim |
200 | 207 |
|
201 | 208 | dtype = node.inputs[0].dtype |
202 | | - if str(dtype).startswith("complex"): |
| 209 | + if dtype in complex_dtypes: |
203 | 210 | raise NotImplementedError( |
204 | | - "Complex inputs not currently supported by solve_triangular in Numba mode" |
| 211 | + _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") |
205 | 212 | ) |
206 | 213 |
|
207 | 214 | @numba_basic.numba_njit(inline="always") |
@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): |
299 | 306 | on_error = op.on_error |
300 | 307 |
|
301 | 308 | dtype = node.inputs[0].dtype |
302 | | - if str(dtype).startswith("complex"): |
303 | | - raise NotImplementedError( |
304 | | - "Complex inputs not currently supported by cholesky in Numba mode" |
305 | | - ) |
| 309 | + if dtype in complex_dtypes: |
| 310 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
306 | 311 |
|
307 | 312 | @numba_basic.numba_njit(inline="always") |
308 | 313 | def nb_cholesky(a): |
@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs): |
1089 | 1094 | transposed = False # TODO: Solve doesnt currently allow the transposed argument |
1090 | 1095 |
|
1091 | 1096 | dtype = node.inputs[0].dtype |
1092 | | - if str(dtype).startswith("complex"): |
1093 | | - raise NotImplementedError( |
1094 | | - "Complex inputs not currently supported by solve in Numba mode" |
1095 | | - ) |
| 1097 | + if dtype in complex_dtypes: |
| 1098 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
1096 | 1099 |
|
1097 | 1100 | if assume_a == "gen": |
1098 | 1101 | solve_fn = _solve_gen |
@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): |
1206 | 1209 | check_finite = op.check_finite |
1207 | 1210 |
|
1208 | 1211 | dtype = node.inputs[0].dtype |
1209 | | - if str(dtype).startswith("complex"): |
1210 | | - raise NotImplementedError( |
1211 | | - "Complex inputs not currently supported by cho_solve in Numba mode" |
1212 | | - ) |
| 1212 | + if dtype in complex_dtypes: |
| 1213 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
1213 | 1214 |
|
1214 | 1215 | @numba_basic.numba_njit(inline="always") |
1215 | 1216 | def cho_solve(c, b): |
|
0 commit comments