|
23 | 23 | """ |
24 | 24 |
|
25 | 25 | import logging |
| 26 | +import warnings |
26 | 27 |
|
27 | 28 | import numpy as np |
28 | 29 |
|
29 | 30 | import pytensor.scalar.basic as ps |
30 | 31 | from pytensor import compile, config |
31 | 32 | from pytensor.compile.ops import ViewOp |
32 | | -from pytensor.graph import FunctionGraph |
33 | | -from pytensor.graph.basic import Constant |
| 33 | +from pytensor.graph import FunctionGraph, rewrite_graph |
| 34 | +from pytensor.graph.basic import Apply, Constant |
34 | 35 | from pytensor.graph.rewriting.basic import ( |
35 | 36 | NodeProcessingGraphRewriter, |
36 | 37 | NodeRewriter, |
|
39 | 40 | copy_stack_trace, |
40 | 41 | in2out, |
41 | 42 | node_rewriter, |
| 43 | + out2in, |
42 | 44 | ) |
43 | 45 | from pytensor.graph.rewriting.db import RewriteDatabase |
| 46 | +from pytensor.link.c.op import COp |
44 | 47 | from pytensor.raise_op import Assert, CheckAndRaise, assert_op |
45 | 48 | from pytensor.scalar.basic import Second |
46 | 49 | from pytensor.tensor.basic import ( |
|
55 | 58 | as_tensor_variable, |
56 | 59 | atleast_Nd, |
57 | 60 | cast, |
| 61 | + empty, |
58 | 62 | fill, |
59 | 63 | get_scalar_constant_value, |
60 | 64 | join, |
|
70 | 74 | from pytensor.tensor.extra_ops import broadcast_arrays |
71 | 75 | from pytensor.tensor.math import Sum, add, eq, variadic_add |
72 | 76 | from pytensor.tensor.shape import Shape_i, shape_padleft |
| 77 | +from pytensor.tensor.subtensor import ( |
| 78 | + Subtensor, |
| 79 | +) |
73 | 80 | from pytensor.tensor.type import DenseTensorType, TensorType |
74 | 81 | from pytensor.tensor.variable import TensorConstant, TensorVariable |
75 | 82 | from pytensor.utils import NoDuplicateOptWarningFilter |
@@ -1356,3 +1363,144 @@ def local_join_of_alloc(fgraph, node): |
1356 | 1363 | new_out = alloc(new_join, *post_join_shape) |
1357 | 1364 | copy_stack_trace(node.outputs[0], new_out) |
1358 | 1365 | return [new_out] |
| 1366 | + |
| 1367 | + |
| 1368 | +class BufferSplit(Subtensor): |
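| | +    """`Subtensor` used to carve out the slice of the shared join buffer that one input writes into.""" |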
| 1369 | +    view_map = {}  # It's a lie, so PyTensor doesn't complain that we are mutating the same input in parallel |
| 1370 | + |
| 1371 | + def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: |
| 1372 | + return False |
| 1373 | + |
| 1374 | + |
| 1375 | +class BufferJoin(COp): |
| 1376 | +    """ |
| 1377 | +    Returns the pre-allocated buffer once all of its in-place updates have run. |
| 1378 | +    The extra ``buffer_updates`` inputs only add the dependencies needed to schedule the buffer writes before this Op. |
| 1379 | +    """ |
| 1380 | + |
| 1381 | + # view_map = {0: [0]} |
| 1382 | + # Mapping from Type to C code (and version) to use. |
| 1383 | + # In the C code, the name of the input variable is %(iname)s, |
| 1384 | + # the output variable is %(oname)s. |
| 1385 | + c_code_and_version: dict = {} |
| 1386 | + __props__: tuple = () |
| 1387 | + _f16_ok: bool = True |
| 1388 | + destroy_map = {0: [0]} |
| 1389 | + |
| 1390 | + def make_node(self, buffer_source, *buffer_updates): |
| 1391 | + out = buffer_source.type() |
| 1392 | + return Apply(self, [buffer_source, *buffer_updates], [out]) |
| 1393 | + |
| 1394 | + def perform(self, node, inputs, output_storage): |
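| | +        # The output is just the buffer (first input); the update inputs only enforce that the in-place writes happen first. |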
| 1395 | + output_storage[0][0] = inputs[0] |
| 1396 | + |
| 1397 | + def c_code(self, node, nodename, inp, out, sub): |
| 1398 | + iname, *_ = inp |
| 1399 | + [oname] = out |
| 1400 | + fail = sub["fail"] |
| 1401 | + |
| 1402 | + itype = node.inputs[0].type.__class__ |
| 1403 | + if itype in self.c_code_and_version: |
| 1404 | + code, version = self.c_code_and_version[itype] |
| 1405 | + return code % locals() |
| 1406 | + |
| 1407 | + # Else, no C code |
| 1408 | + raise NotImplementedError() |
| 1409 | + |
| 1410 | + def c_code_cache_version(self): |
| 1411 | + version = [] |
| 1412 | +        # If any of the C code is unversioned, we have to return () |
| 1413 | +        # Else, we return a tuple of (type name, version) pairs. |
| 1414 | + for t, (c, v) in sorted( |
| 1415 | + self.c_code_and_version.items(), key=lambda pair: str(pair[0]) |
| 1416 | + ): |
| 1417 | + if not v: |
| 1418 | + warnings.warn( |
| 1419 | +                    f"Type {t} has C code for BufferJoin, but it has no " |
| 1420 | +                    "version. You should add a 'version' keyword " |
| 1421 | +                    "arg when registering the C code.", |
| 1422 | + stacklevel=2, |
| 1423 | + ) |
| 1424 | + return () |
| 1425 | + version.append((str(t), v)) |
| 1426 | + |
| 1427 | + return tuple(version) |
| 1428 | + |
| 1429 | + def infer_shape(self, fgraph, node, input_shapes): |
| 1430 | + return [input_shapes[0]] |
| 1431 | + |
| 1432 | + |
| 1433 | +buffer_join = BufferJoin() |
| 1434 | + |
| 1435 | + |
| 1436 | +@node_rewriter([Join], inplace=True) |
| 1437 | +def inplace_join(fgraph, node): |
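| | +    """Make the inputs of a `Join` write directly into a pre-allocated output buffer. |
| | + |
| | +    Instead of each Elemwise allocating its own output and `Join` copying them into |
| | +    a new array, we allocate the joined output up front and rebuild every Elemwise |
| | +    so it writes into its slice of that buffer in place. The final `buffer_join` |
| | +    simply returns the buffer once all the updates have run. |
| | + |
| | +    Schematically, ``join(0, exp(x), log(y))`` becomes a graph where ``exp`` writes |
| | +    into ``buffer[:n]`` and ``log`` writes into ``buffer[n:]``. |
| | +    """ |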
| 1438 | + axis, *tensors = node.inputs |
| 1439 | + |
| 1440 | + if len(tensors) == 1: |
| 1441 | + return tensors |
| 1442 | + |
| 1443 | + if not isinstance(axis, Constant): |
| 1444 | + return |
| 1445 | + |
| 1446 | + shape_feature = getattr(fgraph, "shape_feature", None) |
| 1447 | + if shape_feature is None: |
| 1448 | + return |
| 1449 | + |
| 1450 | + static_axis = int(axis.data) |
| 1451 | + |
| 1452 | + [out] = node.outputs |
| 1453 | + out_shape = shape_feature.shape_of[out] |
| 1454 | + buffer = empty(out_shape, dtype=out.dtype) |
| 1455 | + |
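| | +    # Carve the buffer into consecutive slices along the join axis; each joined Elemwise will write its result directly into its slice. |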
| 1456 | + empty_slices = (slice(None),) * static_axis |
| 1457 | + prev_start = None |
| 1458 | + buffer_updates = [] |
| 1459 | + for i, y in enumerate(tensors): |
| 1460 | +        if y.owner is None or not isinstance(y.owner.op, Elemwise): |
| 1461 | +            # We only know how to write Elemwise outputs into the buffer in place |
| 1462 | + return None |
| 1463 | + |
| 1464 | + if prev_start is None: |
| 1465 | + end = shape_feature.shape_of[y][static_axis] |
| 1466 | + elif i == (len(tensors) - 1): |
| 1467 | + end = None |
| 1468 | + else: |
| 1469 | + end = prev_start + shape_feature.shape_of[y][static_axis] |
| 1470 | + tmp_subtensor = buffer[(*empty_slices, slice(prev_start, end))] |
| 1471 | + buffer_view = BufferSplit(tmp_subtensor.owner.op.idx_list)( |
| 1472 | + *tmp_subtensor.owner.inputs |
| 1473 | + ) |
| 1474 | + prev_start = end |
| 1475 | + |
| 1476 | + from pytensor.tensor.rewriting.elemwise import FusionOptimizer |
| 1477 | + |
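| | +        # Rebuild the Elemwise as a Composite that also takes the buffer slice as an input and overwrites it in place. |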
| 1478 | + scalar_inputs, scalar_outputs = FusionOptimizer.elemwise_to_scalar( |
| 1479 | + (*y.owner.inputs, buffer_view), y.owner.outputs |
| 1480 | + ) |
| 1481 | + |
| 1482 | +        # Make y overwrite the buffer |
| 1483 | + inplace_pattern = dict(y.owner.op.inplace_pattern) |
| 1484 | + y_idx = y.owner.outputs.index(y) |
| 1485 | + inplace_pattern[y_idx] = len(y.owner.inputs) |
| 1486 | + |
| 1487 | + new_op = Elemwise( |
| 1488 | + ps.Composite(scalar_inputs, scalar_outputs), inplace_pattern=inplace_pattern |
| 1489 | + ) |
| 1490 | + buffer_update = new_op(*y.owner.inputs, buffer_view, return_list=True)[y_idx] |
| 1491 | + buffer_updates.append(buffer_update) |
| 1492 | + |
| 1493 | + out = [buffer_join(buffer, *buffer_updates)] |
| 1494 | + out = rewrite_graph( |
| 1495 | + out, include=("canonicalize",), exclude=("local_useless_composite_outputs",) |
| 1496 | + ) |
| 1497 | + return out |
| 1498 | + |
| 1499 | + |
| 1500 | +compile.optdb.register( |
| 1501 | + inplace_join.__name__, |
| 1502 | + out2in(inplace_join), |
| 1503 | + "fast_run", |
| 1504 | + "inplace", |
| 1505 | + position=50.51, # After the fusion inplace |
| 1506 | +) |