|
12 | 12 | from pytensor.compile.mode import get_default_mode, get_mode
|
13 | 13 | from pytensor.compile.ops import DeepCopyOp, deep_copy_op
|
14 | 14 | from pytensor.configdefaults import config
|
15 |
| -from pytensor.graph.basic import equal_computations |
| 15 | +from pytensor.graph.basic import equal_computations, vars_between |
16 | 16 | from pytensor.graph.fg import FunctionGraph
|
17 | 17 | from pytensor.graph.rewriting.basic import check_stack_trace, out2in
|
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery
|
|
31 | 31 | )
|
32 | 32 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
33 | 33 | from pytensor.tensor.math import (
|
| 34 | + Sum, |
34 | 35 | add,
|
35 | 36 | bitwise_and,
|
36 | 37 | bitwise_or,
|
@@ -1300,6 +1301,44 @@ def test_local_join_make_vector():
|
1300 | 1301 | assert check_stack_trace(f, ops_to_check="all")
|
1301 | 1302 |
|
1302 | 1303 |
|
| 1304 | +def test_local_sum_make_vector(): |
| 1305 | + a, b, c = scalars("abc") |
| 1306 | + mv = MakeVector(config.floatX) |
| 1307 | + output = mv(a, b, c).sum() |
| 1308 | + |
| 1309 | + output = rewrite_graph(output) |
| 1310 | + between = vars_between([a, b, c], [output]) |
| 1311 | + for var in between: |
| 1312 | + assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector)) |
| 1313 | + |
| 1314 | + # Check for empty sum |
| 1315 | + a, b, c = scalars("abc") |
| 1316 | + mv = MakeVector(config.floatX) |
| 1317 | + output = mv(a, b, c).sum(axis=[]) |
| 1318 | + |
| 1319 | + output = rewrite_graph(output) |
| 1320 | + between = vars_between([a, b, c], [output]) |
| 1321 | + for var in between: |
| 1322 | + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
| 1323 | + |
| 1324 | + # Check empty MakeVector |
| 1325 | + mv = MakeVector(config.floatX) |
| 1326 | + output = mv().sum() |
| 1327 | + |
| 1328 | + output = rewrite_graph(output) |
| 1329 | + between = vars_between([a, b, c], [output]) |
| 1330 | + for var in between: |
| 1331 | + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
| 1332 | + |
| 1333 | + mv = MakeVector(config.floatX) |
| 1334 | + output = mv(a).sum() |
| 1335 | + |
| 1336 | + output = rewrite_graph(output) |
| 1337 | + between = vars_between([a, b, c], [output]) |
| 1338 | + for var in between: |
| 1339 | + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
| 1340 | + |
| 1341 | + |
1303 | 1342 | @pytest.mark.parametrize(
|
1304 | 1343 | "dtype",
|
1305 | 1344 | [
|
|
0 commit comments