@@ -443,7 +443,7 @@ The following is an example that distributes dot products across additions.
443
443
.. code ::
444
444
445
445
import pytensor
446
- import pytensor.tensor as at
446
+ import pytensor.tensor as pt
447
447
from pytensor.graph.rewriting.kanren import KanrenRelationSub
448
448
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
449
449
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -462,7 +462,7 @@ The following is an example that distributes dot products across additions.
462
462
)
463
463
464
464
# Tell `kanren` that `add` is associative
465
- fact(associative, at .add)
465
+ fact(associative, pt.add)
466
466
467
467
468
468
def dot_distributeo(in_lv, out_lv):
@@ -473,13 +473,13 @@ The following is an example that distributes dot products across additions.
473
473
# Make sure the input is a `_dot`
474
474
eq(in_lv, etuple(_dot, A_lv, add_term_lv)),
475
475
# Make sure the term being `_dot`ed is an `add`
476
- heado(at .add, add_term_lv),
476
+ heado(pt.add, add_term_lv),
477
477
# Flatten the associative pairings of `add` operations
478
478
assoc_flatten(add_term_lv, add_flat_lv),
479
479
# Get the flattened `add` arguments
480
480
tailo(add_cdr_lv, add_flat_lv),
481
481
# Add all the `_dot`ed arguments and set the output
482
- conso(at .add, dot_cdr_lv, out_lv),
482
+ conso(pt.add, dot_cdr_lv, out_lv),
483
483
# Apply the `_dot` to all the flattened `add` arguments
484
484
mapo(lambda x, y: conso(_dot, etuple(A_lv, x), y), add_cdr_lv, dot_cdr_lv),
485
485
)
@@ -490,10 +490,10 @@ The following is an example that distributes dot products across additions.
490
490
491
491
Below, we apply `dot_distribute_rewrite ` to a few example graphs. First we create simple test graph:
492
492
493
- >>> x_at = at .vector(" x" )
494
- >>> y_at = at .vector(" y" )
495
- >>> A_at = at .matrix(" A" )
496
- >>> test_at = A_at .dot(x_at + y_at)
493
+ >>> x_at = pt.vector("x")
494
+ >>> y_at = pt.vector("y")
495
+ >>> A_at = pt.matrix("A")
496
+ >>> test_at = A_at.dot(x_at + y_at)
497
497
>>> print (pytensor.pprint(test_at))
498
498
(A @ (x + y))
499
499
@@ -506,18 +506,18 @@ Next we apply the rewrite to the graph:
506
506
We see that the dot product has been distributed, as desired. Now, let's try a
507
507
few more test cases:
508
508
509
- >>> z_at = at .vector(" z" )
510
- >>> w_at = at .vector(" w" )
511
- >>> test_at = A_at .dot((x_at + y_at) + (z_at + w_at))
509
+ >>> z_at = pt.vector("z")
510
+ >>> w_at = pt.vector("w")
511
+ >>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at))
512
512
>>> print (pytensor.pprint(test_at))
513
513
(A @ ((x + y) + (z + w)))
514
514
>>> res = rewrite_graph(test_at, include = [], custom_rewrite = dot_distribute_rewrite, clone = False )
515
515
>>> print (pytensor.pprint(res))
516
516
(((A @ x) + (A @ y)) + ((A @ z) + (A @ w)))
517
517
518
- >>> B_at = at .matrix(" B" )
519
- >>> w_at = at .vector(" w" )
520
- >>> test_at = A_at .dot(x_at + (y_at + B_at .dot(z_at + w_at)))
518
+ >>> B_at = pt.matrix("B")
519
+ >>> w_at = pt.vector("w")
520
+ >>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at)))
521
521
>>> print (pytensor.pprint(test_at))
522
522
(A @ (x + (y + ((B @ z) + (B @ w)))))
523
523
>>> res = rewrite_graph(test_at, include = [], custom_rewrite = dot_distribute_rewrite, clone = False )
0 commit comments