From cc1a7d4ed11d8ae93fb23f39602dd8a584387897 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 16 Nov 2025 14:01:16 +0100 Subject: [PATCH] Fix wrap_jax when there is a mix of statically known and unknown shapes --- pytensor/link/jax/ops.py | 7 ++++--- tests/link/jax/test_wrap_jax.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index cada35afdd..38978eef75 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.gradient import DisconnectedType from pytensor.graph import Apply, Op, Variable -from pytensor.tensor.basic import infer_static_shape +from pytensor.tensor.basic import as_tensor, infer_static_shape from pytensor.tensor.type import TensorType @@ -384,7 +384,7 @@ def _find_output_types( try: shape_evaluation_function = function( [], - resolved_input_shapes, + [as_tensor(s, dtype="int64") for s in resolved_input_shapes], on_unused_input="ignore", mode=Mode(linker="py", optimizer="fast_compile"), ) @@ -394,7 +394,7 @@ def _find_output_types( "Please provide inputs with fully determined shapes by " "calling pt.specify_shape." ) from e - resolved_input_shapes = shape_evaluation_function() + resolved_input_shapes = [tuple(s) for s in shape_evaluation_function()] # Determine output types using jax.eval_shape with dummy inputs output_metadata_storage = {} @@ -422,6 +422,7 @@ def wrapped_jax_function(input_arrays): output_static = output_metadata_storage["output_static"] # If we used shape evaluation, set all output shapes to unknown + # TODO: This is throwing away potential static shape information. if requires_shape_evaluation: output_types = [ TensorType( diff --git a/tests/link/jax/test_wrap_jax.py b/tests/link/jax/test_wrap_jax.py index 2052b5f4db..b328ab22c5 100644 --- a/tests/link/jax/test_wrap_jax.py +++ b/tests/link/jax/test_wrap_jax.py @@ -559,3 +559,15 @@ def f(x, y): compare_jax_and_py([x, y], [out, *grad_out], test_values) else: compare_jax_and_py([x, y], [out], test_values) + + +def test_mixed_static_shape(): + x_unknown = shared(np.ones((3,))) + x_known = shared(np.ones((4,)), shape=(4,)) + + def f(x1, x2): + return jax.numpy.concatenate([x1, x2]) + + assert wrap_jax(f)(x_known, x_known).type.shape == (8,) + assert wrap_jax(f)(x_known, x_unknown).type.shape == (None,) + assert wrap_jax(f)(x_unknown, x_known).type.shape == (None,)