The compiled logprob graphs do not enforce the input shapes #203
Description
Let's consider the following artificial example. Although the variable being conditioned on is of shape (3,), the compiled logprob graph accepts any input shape as long as the input is 1-dimensional:
```python
import aesara
import aeppl
import aesara.tensor as at

srng = at.random.RandomStream(0)
x_rv = srng.normal(0, 1, size=(10,))
y = x_rv[[1, 2, 5]]
y.name = "y"

logprob, (y_vv,) = aeppl.joint_logprob(y)
aesara.dprint(logprob)
# Sum{acc_dtype=float64} [id A]
#  |Check{sigma > 0} [id B] 'y_logprob'
#    |Elemwise{sub,no_inplace} [id C]
#    | |Elemwise{sub,no_inplace} [id D]
#    | | |Elemwise{mul,no_inplace} [id E]
#    | | | |InplaceDimShuffle{x} [id F]
#    | | | | |TensorConstant{-0.5} [id G]
#    | | | |Elemwise{pow,no_inplace} [id H]
#    | | |   |Elemwise{true_div,no_inplace} [id I]
#    | | |   | |Elemwise{sub,no_inplace} [id J]
#    | | |   | | |y_vv [id K]
#    | | |   | | |InplaceDimShuffle{x} [id L]
#    | | |   | |   |TensorConstant{0} [id M]
#    | | |   | |InplaceDimShuffle{x} [id N]
#    | | |   |   |TensorConstant{1} [id O]
#    | | |   |InplaceDimShuffle{x} [id P]
#    | | |     |TensorConstant{2} [id Q]
#    | | |InplaceDimShuffle{x} [id R]
#    | |   |Elemwise{log,no_inplace} [id S]
#    | |     |Elemwise{sqrt,no_inplace} [id T]
#    | |       |TensorConstant{6.283185307179586} [id U]
#    | |InplaceDimShuffle{x} [id V]
#    |   |Elemwise{log,no_inplace} [id W]
#    |     |TensorConstant{1} [id O]
#    |All [id X]
#      |Elemwise{gt,no_inplace} [id Y]
#        |TensorConstant{1} [id O]
#        |TensorConstant{0.0} [id Z]
```
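The printed graph computes the elementwise normal log-density of `y_vv` (with `mu = 0`, `sigma = 1`) and then a `Sum`. None of its nodes refers to the length 3 of `y`, so any 1-D input goes through. The plain-Python mirror below is a sketch of that arithmetic, not AePPL code (the name `normal_logprob_sum` is hypothetical). It should reproduce the two values the compiled function returns for `[1.]` and `[1., 1., 1.]`:

```python
import math


def normal_logprob_sum(values, mu=0.0, sigma=1.0):
    """Plain-Python mirror of the printed graph: summed N(mu, sigma) log-densities."""
    # Mirrors Check{sigma > 0}: refuse non-positive scales
    if not sigma > 0:
        raise ValueError("sigma > 0")
    total = 0.0
    for v in values:
        # -0.5 * ((v - mu) / sigma) ** 2 - log(sqrt(2 * pi)) - log(sigma), as in the graph
        total += (
            -0.5 * ((v - mu) / sigma) ** 2
            - math.log(math.sqrt(2 * math.pi))
            - math.log(sigma)
        )
    # Sum{...}: the reduction runs over however many elements were passed
    return total


print(normal_logprob_sum([1.0]))            # about -1.41894
print(normal_logprob_sum([1.0, 1.0, 1.0]))  # about -4.25682
```

Because the reduction is a plain sum, a length-3 input yields three times the length-1 result. Nothing in the arithmetic itself flags the mismatch with `y`'s shape.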
```python
logprob_fn = aesara.function([y_vv], logprob)

try:
    print(logprob_fn(1.))
except Exception as e:
    print(e)
# Wrong number of dimensions: expected 1, got 0 with shape ().

print(logprob_fn([1.]))
# -1.4189385332046727
print(logprob_fn([1., 1., 1.]))
# -4.2568155996140185
```

I would, however, expect the compiled function to behave like the following:
```python
y_at = at.tensor(shape=(3,), dtype="float32")
out = at.sum(y_at)
fn = aesara.function((y_at,), out)

try:
    print(fn([1.]))
except Exception as e:
    print(e)
# The type's shape ((3,)) is not compatible with the data's ((1,))
```

This was already noted in #51 (reply in thread). It happens because we do not enforce any shape constraints when cloning the random variables to create the value variables:
Line 100 in eb55106:

```python
vv = rv.clone()
```
This constraint would have been difficult to enforce with the previous `conditional_logprob` interface, which asked users to pass their own value variables. Now, however, we have full control over how value variables are created, and could therefore specify their shape and dtype.
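A toy model makes the difference concrete. It is a sketch of a static type in which `None` marks an unknown dimension. The class `ToyTensorType` and its `check` method are hypothetical and are not Aesara's API. The sketch assumes two things, based on the behavior above: cloning `y` yields a type that fixes only the number of dimensions, and a value variable built with `y`'s full shape would carry `(3,)`. The exact dtype comparison is a simplification and is stricter than a real compiled function, which may cast inputs.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToyTensorType:
    """Hypothetical stand-in for a tensor type; None marks an unknown dimension."""

    dtype: str
    shape: Tuple[Optional[int], ...]

    def check(self, data_shape: Sequence[int], data_dtype: str) -> None:
        # The number of dimensions must always match
        if len(data_shape) != len(self.shape):
            raise TypeError(
                f"Wrong number of dimensions: expected {len(self.shape)}, got {len(data_shape)}"
            )
        # Every statically known dimension must match; None accepts any length
        for static, actual in zip(self.shape, data_shape):
            if static is not None and static != actual:
                raise TypeError(
                    f"shape {self.shape} is not compatible with data shape {tuple(data_shape)}"
                )
        # Simplification: require an exact dtype match (no casting)
        if data_dtype != self.dtype:
            raise TypeError(f"expected dtype {self.dtype}, got {data_dtype}")


# Assumed type of the cloned value variable: only ndim is fixed
cloned_vv_type = ToyTensorType("float64", (None,))
cloned_vv_type.check((1,), "float64")  # accepted
cloned_vv_type.check((3,), "float64")  # accepted

# Assumed type of a value variable created with y's full shape
typed_vv_type = ToyTensorType("float64", (3,))
typed_vv_type.check((3,), "float64")  # accepted
try:
    typed_vv_type.check((1,), "float64")
except TypeError as e:
    print(e)
```

With the `(None,)` type, both lengths pass, just like `logprob_fn` above. With the `(3,)` type, the length-1 input is rejected, which is the behavior `fn` shows.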