This repository was archived by the owner on Nov 17, 2025. It is now read-only.

The compiled logprob graphs do not enforce the input shapes #203

@rlouf

Description


Let's consider the following artificial example. Although the variable being conditioned on has shape (3,), the compiled logprob graph accepts a 1-dimensional input of any length:

```python
import aesara
import aeppl
import aesara.tensor as at

srng = at.random.RandomStream(0)
x_rv = srng.normal(0, 1, size=(10,))
# Select three entries, so `y` always has exactly 3 elements
y = x_rv[[1, 2, 5]]
y.name = "y"

# `joint_logprob` creates the value variable `y_vv` for `y` itself
logprob, (y_vv,) = aeppl.joint_logprob(y)
aesara.dprint(logprob)
# Sum{acc_dtype=float64} [id A]
#  |Check{sigma > 0} [id B] 'y_logprob'
#    |Elemwise{sub,no_inplace} [id C]
#    | |Elemwise{sub,no_inplace} [id D]
#    | | |Elemwise{mul,no_inplace} [id E]
#    | | | |InplaceDimShuffle{x} [id F]
#    | | | | |TensorConstant{-0.5} [id G]
#    | | | |Elemwise{pow,no_inplace} [id H]
#    | | |   |Elemwise{true_div,no_inplace} [id I]
#    | | |   | |Elemwise{sub,no_inplace} [id J]
#    | | |   | | |y_vv [id K]
#    | | |   | | |InplaceDimShuffle{x} [id L]
#    | | |   | |   |TensorConstant{0} [id M]
#    | | |   | |InplaceDimShuffle{x} [id N]
#    | | |   |   |TensorConstant{1} [id O]
#    | | |   |InplaceDimShuffle{x} [id P]
#    | | |     |TensorConstant{2} [id Q]
#    | | |InplaceDimShuffle{x} [id R]
#    | |   |Elemwise{log,no_inplace} [id S]
#    | |     |Elemwise{sqrt,no_inplace} [id T]
#    | |       |TensorConstant{6.283185307179586} [id U]
#    | |InplaceDimShuffle{x} [id V]
#    |   |Elemwise{log,no_inplace} [id W]
#    |     |TensorConstant{1} [id O]
#    |All [id X]
#      |Elemwise{gt,no_inplace} [id Y]
#        |TensorConstant{1} [id O]
#        |TensorConstant{0.0} [id Z]
```

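```python
logprob_fn = aesara.function([y_vv], logprob)

# The number of dimensions is checked, so a 0-d input is rejected...
try:
    print(logprob_fn(1.))
except Exception as e:
    print(e)
# Wrong number of dimensions: expected 1, got 0 with shape ().

# ...but the vector's length is not: both calls below succeed
print(logprob_fn([1.]))
# -1.4189385332046727
print(logprob_fn([1., 1., 1.]))
# -4.2568155996140185
```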
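As a sanity check on the printed values: the graph above is the standard-normal log-density summed over the entries of `y_vv`, so recomputing it by hand with the standard library shows that a length-1 input contributes one term and a length-3 input contributes three. Nothing in the sum itself depends on the expected length, which is why the missing shape check goes unnoticed. `std_normal_logpdf` is a hypothetical helper written for this sketch.

```python
import math


def std_normal_logpdf(x):
    # Same terms as the dprint graph:
    # -0.5 * ((x - 0) / 1) ** 2 - log(sqrt(2 * pi)) - log(1)
    return -0.5 * ((x - 0.0) / 1.0) ** 2 - math.log(math.sqrt(2 * math.pi)) - math.log(1.0)


# The compiled function sums the elementwise terms over whatever vector it
# receives, so the result grows with the input length.
one = sum(std_normal_logpdf(v) for v in [1.0])
three = sum(std_normal_logpdf(v) for v in [1.0, 1.0, 1.0])
print(one)    # approximately -1.4189385
print(three)  # approximately -4.2568156
```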

I would instead expect the compiled function to reject inputs of the wrong shape, like the following:

```python
# A variable with static shape (3,) rejects a length-1 input at call time
y_at = at.tensor(shape=(3,), dtype="float32")
out = at.sum(y_at)
fn = aesara.function((y_at,), out)
try:
    print(fn([1.]))
except Exception as e:
    print(e)
# The type's shape ((3,)) is not compatible with the data's ((1,))
```
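Until value variables carry static shapes, one possible user-side stopgap is to validate argument shapes before calling the compiled function. The sketch below uses only NumPy and is an assumption, not part of AePPL or Aesara: `enforce_shapes` and `normal_logprob_sum` are hypothetical names, and `normal_logprob_sum` stands in for `logprob_fn` so the snippet runs without Aesara. The error text imitates the message above, but this is not Aesara's own check.

```python
import numpy as np


def enforce_shapes(fn, *expected_shapes):
    """Return a wrapper around `fn` that checks positional argument shapes.

    `None` in an expected shape matches any length along that dimension.
    Arguments beyond the number of expected shapes are passed through unchecked.
    """

    def wrapped(*args):
        for arg, expected in zip(args, expected_shapes):
            actual = np.shape(arg)
            compatible = len(actual) == len(expected) and all(
                e is None or e == a for e, a in zip(expected, actual)
            )
            if not compatible:
                raise TypeError(
                    f"The type's shape ({expected}) is not compatible "
                    f"with the data's ({actual})"
                )
        return fn(*args)

    return wrapped


def normal_logprob_sum(y):
    # Stand-in for the compiled `logprob_fn`: summed standard-normal log-density
    y = np.asarray(y, dtype="float64")
    return float(np.sum(-0.5 * y**2 - 0.5 * np.log(2 * np.pi)))


checked_logprob_fn = enforce_shapes(normal_logprob_sum, (3,))
print(checked_logprob_fn([1.0, 1.0, 1.0]))
try:
    checked_logprob_fn([1.0])
except TypeError as e:
    print(e)  # The type's shape ((3,)) is not compatible with the data's ((1,))
```

A wrapper like this only catches mistakes at call time and has to be kept in sync with the model by hand; encoding the shape in the value variable's type would let `aesara.function` perform the check itself, as in the `y_at` example.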

This was already noted in #51, and it is due to the fact that we do not enforce any shape constraints when cloning the random variables to obtain the value variables:

```python
vv = rv.clone()
```

This constraint would have been difficult to enforce with the previous `conditional_logprob` interface, which asked users to pass their own value variables. However, we now have full control over how value variables are created, and could thus specify their shape and dtype.
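A toy model makes the difference between the current and the proposed construction concrete without requiring Aesara. Everything in it is hypothetical: `TensorTypeSketch` is not Aesara's `TensorType`, `clone_value_var` and `make_value_var` are illustrative names, and `y`'s type is assumed to be a float64 vector with static shape `(None,)`, which is what the behaviour of `logprob_fn` suggests. Where the concrete shape `(3,)` would come from (static shape inference on the random variable, or the caller) is left open here.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorTypeSketch:
    """Toy stand-in for a tensor type: a dtype plus a static shape.

    `None` marks a dimension whose length is unknown when the graph is built.
    """

    dtype: str
    shape: Tuple[Optional[int], ...]

    def accepts(self, data_shape):
        # Same number of dimensions, and every known length must match
        return len(data_shape) == len(self.shape) and all(
            s is None or s == d for s, d in zip(self.shape, data_shape)
        )


def clone_value_var(rv_type):
    # Mirrors `vv = rv.clone()`: the value variable inherits the random
    # variable's type unchanged, unknown dimensions included.
    return replace(rv_type)


def make_value_var(rv_type, shape):
    # Hypothetical alternative: keep the dtype, but pin down the full shape.
    # The requested shape must agree with whatever the type already knows.
    if not rv_type.accepts(shape):
        raise ValueError(f"shape {shape} is incompatible with {rv_type.shape}")
    return TensorTypeSketch(dtype=rv_type.dtype, shape=tuple(shape))


# Assumed type of `y` in the example: a float64 vector of unknown length
y_type = TensorTypeSketch(dtype="float64", shape=(None,))
cloned = clone_value_var(y_type)
explicit = make_value_var(y_type, (3,))

print(cloned.accepts((1,)), cloned.accepts((3,)))      # True True
print(explicit.accepts((1,)), explicit.accepts((3,)))  # False True
```

The cloned type accepts both lengths, mirroring `logprob_fn`, whereas the explicitly shaped one behaves like `y_at`.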

Labels: bug, enhancement, good first issue, help wanted
