Skip to content

Commit cc1a7d4

Browse files
committed
Fix wrap_jax when there is a mix of statically known and unknown shapes
1 parent c77d1ef commit cc1a7d4

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

pytensor/link/jax/ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile.mode import Mode
1010
from pytensor.gradient import DisconnectedType
1111
from pytensor.graph import Apply, Op, Variable
12-
from pytensor.tensor.basic import infer_static_shape
12+
from pytensor.tensor.basic import as_tensor, infer_static_shape
1313
from pytensor.tensor.type import TensorType
1414

1515

@@ -384,7 +384,7 @@ def _find_output_types(
384384
try:
385385
shape_evaluation_function = function(
386386
[],
387-
resolved_input_shapes,
387+
[as_tensor(s, dtype="int64") for s in resolved_input_shapes],
388388
on_unused_input="ignore",
389389
mode=Mode(linker="py", optimizer="fast_compile"),
390390
)
@@ -394,7 +394,7 @@ def _find_output_types(
394394
"Please provide inputs with fully determined shapes by "
395395
"calling pt.specify_shape."
396396
) from e
397-
resolved_input_shapes = shape_evaluation_function()
397+
resolved_input_shapes = [tuple(s) for s in shape_evaluation_function()]
398398

399399
# Determine output types using jax.eval_shape with dummy inputs
400400
output_metadata_storage = {}
@@ -422,6 +422,7 @@ def wrapped_jax_function(input_arrays):
422422
output_static = output_metadata_storage["output_static"]
423423

424424
# If we used shape evaluation, set all output shapes to unknown
425+
# TODO: This is throwing away potential static shape information.
425426
if requires_shape_evaluation:
426427
output_types = [
427428
TensorType(

tests/link/jax/test_wrap_jax.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,15 @@ def f(x, y):
559559
compare_jax_and_py([x, y], [out, *grad_out], test_values)
560560
else:
561561
compare_jax_and_py([x, y], [out], test_values)
562+
563+
564+
def test_mixed_static_shape():
565+
x_unknown = shared(np.ones((3,)))
566+
x_known = shared(np.ones((4,)), shape=(4,))
567+
568+
def f(x1, x2):
569+
return jax.numpy.concatenate([x1, x2])
570+
571+
assert wrap_jax(f)(x_known, x_known).type.shape == (8,)
572+
assert wrap_jax(f)(x_known, x_unknown).type.shape == (None,)
573+
assert wrap_jax(f)(x_unknown, x_known).type.shape == (None,)

0 commit comments

Comments
 (0)