@ricardoV94 ricardoV94 commented Oct 18, 2024

This can be pretty useful for time series models as well as normalizing flows.

Here is a simple example:

import numpy as np
import pytensor.tensor as pt

import pymc as pm

rng = np.random.default_rng(37)
x = pm.MvNormal.dist(cov=np.eye(2), size=(128,))

n_layers = 3
A = pt.tensor("A", shape=(n_layers, 2, 2))
b = pt.tensor("b", shape=(n_layers, 2,))

# Repeated layers of Affine transform -> Tanh
for i in range(n_layers):
    y = A[i] @ x + b[i]
    # parametrized leaky-relu would be nicer: https://github.com/pymc-devs/pymc/issues/7543
    # y = pt.switch(y > 0, y, c[i] * y)
    y = pt.tanh(y)

A_test = rng.normal(size=A.type.shape)
b_test = rng.normal(size=b.type.shape)
y_test = rng.uniform(-1, 1, size=y.type.shape)
pm.logp(y, y_test).sum().eval({A: A_test, b: b_test})  # array(-3.54498234)
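For intuition about what a log-probability through a matrix product has to compute, here is a minimal NumPy sketch of the standard change-of-variables formula for a single square, invertible affine layer. It is an illustration only, not the code added in pymc/logprob/linalg.py; the helper names (std_normal_logp, matmul_logp, mvnormal_logp) are hypothetical. It checks the result against the closed-form density of N(0, A Aᵀ).

```python
import numpy as np


def std_normal_logp(x):
    """Log-density of a standard multivariate normal (identity covariance)."""
    return -0.5 * np.sum(x**2 + np.log(2 * np.pi), axis=-1)


def matmul_logp(A, y):
    """Log-density of y = A @ x with x ~ N(0, I), via change of variables.

    Assumes A is square and invertible (hypothetical helper, for illustration).
    """
    x = np.linalg.solve(A, y)  # invert the transform: x = A^{-1} y
    _, logabsdet = np.linalg.slogdet(A)  # log|det A|
    return std_normal_logp(x) - logabsdet


def mvnormal_logp(cov, y):
    """Closed-form log-density of N(0, cov), used only as a reference."""
    k = y.shape[-1]
    _, logdet = np.linalg.slogdet(cov)
    maha = y @ np.linalg.solve(cov, y)  # y^T cov^{-1} y
    return -0.5 * (k * np.log(2 * np.pi) + logdet + maha)


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2))
y = rng.normal(size=2)

lhs = matmul_logp(A, y)
rhs = mvnormal_logp(A @ A.T, y)
print(np.isclose(lhs, rhs))
```

The two values agree because y = A x with x ~ N(0, I) is distributed as N(0, A Aᵀ), and log det(A Aᵀ) = 2 log|det A|.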

@ricardoV94 ricardoV94 requested a review from ferrine October 18, 2024 20:53

codecov bot commented Oct 18, 2024

Codecov Report

Attention: Patch coverage is 87.77778% with 11 lines in your changes missing coverage. Please review.

Project coverage is 92.82%. Comparing base (5352798) to head (5e5e077).

Files with missing lines    Patch %    Lines
pymc/logprob/tensor.py      79.41%     7 Missing ⚠️
pymc/logprob/linalg.py      91.48%     4 Missing ⚠️

@@            Coverage Diff             @@
##             main    #7542      +/-   ##
==========================================
- Coverage   92.85%   92.82%   -0.04%     
==========================================
  Files         105      106       +1     
  Lines       17591    17669      +78     
==========================================
+ Hits        16335    16402      +67     
- Misses       1256     1267      +11     
Files with missing lines           Coverage Δ
pymc/logprob/__init__.py           100.00% <100.00%> (ø)
pymc/logprob/abstract.py           94.28% <100.00%> (+0.16%) ⬆️
pymc/logprob/basic.py              94.28% <100.00%> (ø)
pymc/logprob/mixture.py            95.70% <100.00%> (ø)
pymc/logprob/rewriting.py          100.00% <100.00%> (ø)
pymc/logprob/scan.py               94.90% <ø> (ø)
pymc/logprob/transform_value.py    98.14% <100.00%> (ø)
pymc/logprob/utils.py              92.46% <100.00%> (ø)
pymc/logprob/linalg.py             91.48% <91.48%> (ø)
pymc/logprob/tensor.py             94.48% <79.41%> (-5.52%) ⬇️

Member Author

This is a much more natural order, used in all sorts of PyTensor utilities that require a node/variable and its fgraph.

Comment on lines +277 to +285

@ricardoV94 ricardoV94 Oct 18, 2024

These DimShuffle changes were needed to naturally accommodate A @ x when x is a vector, which looks like:

import pytensor.tensor as pt

A = pt.matrix("A")
x = pt.vector("x")
y = A @ x
y.dprint()
# DropDims{axis=1} [id A]
#  └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B]
#     ├─ A [id C]
#     └─ ExpandDims{axis=1} [id D]
#        └─ x [id E]

It's also stricter and more correct than the limitation we had before, because the concerns are mostly about what comes after the DimShuffle, not what comes before it.
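The same expand, multiply, drop pattern shown in the dprint output can be reproduced in plain NumPy. This is only a sketch of the shape bookkeeping, with arbitrary illustrative shapes, not of the PyTensor rewrite itself:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 2))
x = rng.normal(size=2)

# ExpandDims{axis=1}: turn the vector into a single-column matrix, (2,) -> (2, 1)
x_col = np.expand_dims(x, axis=1)

# Matrix-matrix product, matching the (m,k),(k,n)->(m,n) signature with n = 1
y_mat = A @ x_col

# DropDims{axis=1}: remove the length-1 column axis again, (3, 1) -> (3,)
y = np.squeeze(y_mat, axis=1)

print(y.shape, np.allclose(y, A @ x))
```

The dimension that is dropped at the end is the same one that was added before the product, so the round trip leaves the values of A @ x unchanged.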

@jessegrabowski jessegrabowski left a comment

lgtm : )

@ricardoV94 ricardoV94 merged commit 1249c86 into pymc-devs:main Oct 21, 2024
18 of 20 checks passed
@ricardoV94 ricardoV94 deleted the matmul branch October 21, 2024 09:15