Skip to content

Commit 81b7a1e

Browse files
committed
Use static shape in join_nonshared_inputs
1 parent 5d51953 commit 81b7a1e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc/pytensorf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,13 +596,13 @@ def join_nonshared_inputs(
596596
raise ValueError("Empty list of input variables.")
597597

598598
raveled_inputs = pt.concatenate([var.ravel() for var in inputs])
599+
size = sum(point[var_name].size for var_name in point)
599600

600601
if not make_inputs_shared:
601-
tensor_type = raveled_inputs.type
602-
joined_inputs = tensor_type("joined_inputs")
602+
joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype)
603603
else:
604604
joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
605-
joined_inputs = pytensor.shared(joined_values, "joined_inputs")
605+
joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,))
606606

607607
if pytensor.config.compute_test_value != "off":
608608
joined_inputs.tag.test_value = raveled_inputs.tag.test_value

0 commit comments

Comments
 (0)