Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit c1ac578

Browse files
Make transform tests less specific to input shapes
1 parent f507b87 commit c1ac578

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/test_transforms.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,17 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
190190
a_trans_op = _default_transformed_rv(a.owner.op, a.owner).op
191191
transform = a_trans_op.transform
192192

193+
# Remove the static shape assumptions from the value variable so that it's
194+
# easier to construct the numerical Jacobian reference values in higher
195+
# dimensions
196+
a_value_var_gen = at.tensor(
197+
dtype=a_value_var.type.dtype, shape=(None,) * a_value_var.type.ndim
198+
)
193199
a_forward_fn = aesara.function(
194-
[a_value_var], transform.forward(a_value_var, *a.owner.inputs)
200+
[a_value_var_gen], transform.forward(a_value_var_gen, *a.owner.inputs)
195201
)
196202
a_backward_fn = aesara.function(
197-
[a_value_var], transform.backward(a_value_var, *a.owner.inputs)
203+
[a_value_var_gen], transform.backward(a_value_var_gen, *a.owner.inputs)
198204
)
199205
log_jac_fn = aesara.function(
200206
[a_value_var],

0 commit comments

Comments
 (0)