99from pytensor .compile .mode import Mode
1010from pytensor .gradient import DisconnectedType
1111from pytensor .graph import Apply , Op , Variable
12- from pytensor .tensor .basic import infer_static_shape
12+ from pytensor .tensor .basic import as_tensor , infer_static_shape
1313from pytensor .tensor .type import TensorType
1414
1515
@@ -384,7 +384,7 @@ def _find_output_types(
384384 try :
385385 shape_evaluation_function = function (
386386 [],
387- resolved_input_shapes ,
387+ [ as_tensor ( s , dtype = "int64" ) for s in resolved_input_shapes ] ,
388388 on_unused_input = "ignore" ,
389389 mode = Mode (linker = "py" , optimizer = "fast_compile" ),
390390 )
@@ -394,7 +394,7 @@ def _find_output_types(
394394 "Please provide inputs with fully determined shapes by "
395395 "calling pt.specify_shape."
396396 ) from e
397- resolved_input_shapes = shape_evaluation_function ()
397+ resolved_input_shapes = [ tuple ( s ) for s in shape_evaluation_function ()]
398398
399399 # Determine output types using jax.eval_shape with dummy inputs
400400 output_metadata_storage = {}
@@ -422,6 +422,7 @@ def wrapped_jax_function(input_arrays):
422422 output_static = output_metadata_storage ["output_static" ]
423423
424424 # If we used shape evaluation, set all output shapes to unknown
425+ # TODO: This is throwing away potential static shape information.
425426 if requires_shape_evaluation :
426427 output_types = [
427428 TensorType (
0 commit comments