Derive logprob of matmul #7542
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #7542 +/- ##
==========================================
- Coverage 92.85% 92.82% -0.04%
==========================================
Files 105 106 +1
Lines 17591 17669 +78
==========================================
+ Hits 16335 16402 +67
- Misses 1256 1267 +11
… direct valued nodes
pymc/logprob/utils.py
Outdated
This is a much more natural argument order, and it is the one used in all sorts of PyTensor utilities that take a node/variable and its fgraph.
pymc/logprob/tensor.py
Outdated
These DimShuffle changes were needed to naturally accommodate A @ x when x is a vector, which looks like:
import pytensor.tensor as pt
A = pt.matrix("A")
x = pt.vector("x")
y = A @ x
y.dprint()
# DropDims{axis=1} [id A]
# └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B]
# ├─ A [id C]
# └─ ExpandDims{axis=1} [id D]
# └─ x [id E]
It's also stricter / more correct than the limitation we had before, because the concerns are much more about what comes after the DimShuffle than about what comes before it.
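To make the graph above concrete: for a vector `x`, `A @ x` is computed by expanding `x` into a column, doing an ordinary matrix product, and dropping the extra axis again. Below is a rough pure-Python sketch of those three steps. The helper names are hypothetical and only mirror the node names in the printout; they are not PyTensor API.

```python
# Hypothetical helpers mirroring the three graph steps; not PyTensor API.

def expand_dims_last(x):
    """Turn a length-k vector into a (k, 1) column (like ExpandDims{axis=1})."""
    return [[xi] for xi in x]


def matmul(a, b):
    """Plain (m, k) @ (k, n) -> (m, n) matrix product on nested lists."""
    return [
        [sum(a_row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for a_row in a
    ]


def drop_dims_last(m):
    """Drop the trailing length-1 axis (like DropDims{axis=1})."""
    assert all(len(row) == 1 for row in m)
    return [row[0] for row in m]


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # shape (3, 2)
x = [10.0, 1.0]                            # shape (2,)

# ExpandDims -> matrix product -> DropDims, as in the printed graph
y = drop_dims_last(matmul(A, expand_dims_last(x)))

# Direct matrix-vector product for comparison
y_direct = [sum(a * b for a, b in zip(row, x)) for row in A]
```

Both routes give the same vector, which is why a logprob rewrite that sees the matrix product has to look through the surrounding DimShuffles.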
lgtm : )
This can be pretty useful for timeseries models as well as normalizing flows.
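For context, the density of `y = A @ x` follows from the change-of-variables formula when `A` is square and invertible: `logp_y(y) = logp_x(A^{-1} y) - log|det A|`. The sketch below works this out in pure Python for a 2x2 matrix and independent standard normal `x`. It assumes this is the case the PR targets, which the thread does not spell out, and the function names are hypothetical rather than PyMC API.

```python
import math


def std_normal_logp(v):
    """Log-density of a standard normal at v."""
    return -0.5 * v * v - 0.5 * math.log(2.0 * math.pi)


def matmul_logp_2x2(A, y):
    """Joint log p(y) for y = A @ x, with x ~ iid N(0, 1) and A 2x2 invertible.

    Hypothetical helper illustrating change of variables; not PyMC API.
    """
    (a, b), (c, d) = A
    det = a * d - b * c
    if det == 0.0:
        raise ValueError("A must be invertible")
    # Recover x = A^{-1} y using the closed-form 2x2 inverse
    x0 = (d * y[0] - b * y[1]) / det
    x1 = (-c * y[0] + a * y[1]) / det
    # Base density at x, minus the log absolute Jacobian determinant
    return std_normal_logp(x0) + std_normal_logp(x1) - math.log(abs(det))


A = [[2.0, 0.0], [0.0, 3.0]]
y = [2.0, 3.0]  # corresponds to x = [1.0, 1.0]
logp = matmul_logp_2x2(A, y)
```

The `- log|det A|` term is what separates this from simply evaluating the base density at the inverted value; it accounts for how `A` stretches or shrinks volume.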
Here is a simple example: