|
3 | 3 | import numpy as np |
4 | 4 |
|
5 | 5 | from pytensor import config |
| 6 | +from pytensor.link.numba.cache import compile_numba_function_src |
6 | 7 | from pytensor.link.numba.dispatch import basic as numba_basic |
7 | 8 | from pytensor.link.numba.dispatch.basic import ( |
8 | 9 | generate_fallback_impl, |
|
30 | 31 | from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric |
31 | 32 | from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular |
32 | 33 | from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal |
| 34 | +from pytensor.link.numba.dispatch.string_codegen import ( |
| 35 | + CODE_TOKEN, |
| 36 | + build_source_code, |
| 37 | +) |
33 | 38 | from pytensor.tensor.slinalg import ( |
34 | 39 | LU, |
35 | 40 | QR, |
@@ -222,24 +227,65 @@ def lu_factor(a): |
222 | 227 |
|
223 | 228 | @register_funcify_default_op_cache_key(BlockDiagonal) |
224 | 229 | def numba_funcify_BlockDiagonal(op, node, **kwargs): |
225 | | - dtype = node.outputs[0].dtype |
| 230 | + """Codegen something like: |
226 | 231 |
|
227 | | - @numba_basic.numba_njit |
228 | | - def block_diag(*arrs): |
229 | | - shapes = np.array([a.shape for a in arrs], dtype="int") |
230 | | - out_shape = [int(s) for s in np.sum(shapes, axis=0)] |
231 | | - out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype) |
| 232 | + def block_diagonal(arr0, arr1, arr2): |
| 233 | + out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0] |
| 234 | + out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1] |
| 235 | + out = np.zeros((out_r, out_c), dtype=np.float64) |
232 | 236 |
|
233 | 237 | r, c = 0, 0 |
234 | | - # no strict argument because it is incompatible with numba |
235 | | - for arr, shape in zip(arrs, shapes): |
236 | | - rr, cc = shape |
237 | | - out[r : r + rr, c : c + cc] = arr |
238 | | - r += rr |
239 | | - c += cc |
| 238 | + rr, cc = arr0.shape |
| 239 | + out[r: r + rr, c: c + cc] = arr0 |
| 240 | + r += rr |
| 241 | + c += cc |
| 242 | +
|
| 243 | + rr, cc = arr1.shape |
| 244 | + out[r: r + rr, c: c + cc] = arr1 |
| 245 | + r += rr |
| 246 | + c += cc |
| 247 | +
|
| 248 | + rr, cc = arr2.shape |
| 249 | + out[r: r + rr, c: c + cc] = arr2 |
| 250 | + r += rr |
| 251 | + c += cc |
| 252 | +
|
240 | 253 | return out |
| 254 | + """ |
| 255 | + dtype = node.outputs[0].dtype |
| 256 | + n_inp = len(node.inputs) |
| 257 | + |
| 258 | + arg_names = [f"arr{i}" for i in range(n_inp)] |
| 259 | + code = [ |
| 260 | + f"def block_diagonal({', '.join(arg_names)}):", |
| 261 | + CODE_TOKEN.INDENT, |
| 262 | + f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}", |
| 263 | + f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}", |
| 264 | + f"out = np.zeros((out_r, out_c), dtype=np.{dtype})", |
| 265 | + CODE_TOKEN.EMPTY_LINE, |
| 266 | + "r, c = 0, 0", |
| 267 | + ] |
| 268 | + for i, arg_name in enumerate(arg_names): |
| 269 | + code.extend( |
| 270 | + [ |
| 271 | + f"rr, cc = {arg_name}.shape", |
| 272 | + f"out[r: r + rr, c: c + cc] = {arg_name}", |
| 273 | + "r += rr", |
| 274 | + "c += cc", |
| 275 | + CODE_TOKEN.EMPTY_LINE, |
| 276 | + ] |
| 277 | + ) |
| 278 | + code.append("return out") |
241 | 279 |
|
242 | | - return block_diag |
| 280 | + code_txt = build_source_code(code) |
| 281 | + block_diag = compile_numba_function_src( |
| 282 | + code_txt, |
| 283 | + "block_diagonal", |
| 284 | + globals() | {"np": np}, |
| 285 | + ) |
| 286 | + |
| 287 | + cache_key = 1 |
| 288 | + return numba_basic.numba_njit(block_diag), cache_key |
243 | 289 |
|
244 | 290 |
|
245 | 291 | @register_funcify_default_op_cache_key(Solve) |
|
0 commit comments