 from pytensor.tensor.elemwise import get_normalized_batch_axes
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
-from pytensor.tensor.type_other import NoneConst
+from pytensor.tensor.type_other import NoneConst, NoneTypeT
 from pytensor.tensor.variable import TensorConstant, TensorVariable


@@ -401,8 +401,6 @@ class SpecifyShape(COp):
     _output_type_depends_on_input_value = True

     def make_node(self, x, *shape):
-        from pytensor.tensor.basic import get_underlying_scalar_constant_value
-
         x = ptb.as_tensor_variable(x)

         shape = tuple(
@@ -428,11 +426,9 @@ def make_node(self, x, *shape):
         for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)):
             if xts is not None:
                 type_shape[i] = xts
-            else:
+            elif not isinstance(s.type, NoneTypeT):
                 try:
-                    type_s = get_underlying_scalar_constant_value(s)
-                    if type_s is not None:
-                        type_shape[i] = int(type_s)
+                    type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
                 except NotScalarConstantError:
                     pass

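With this change, `make_node` leaves a dimension's static length unknown when the corresponding shape entry is a `NoneTypeT` constant, and only folds genuinely constant entries into the output type. A minimal, untested sketch of the resulting behaviour, assuming the usual public helpers `pt.matrix` and `pt.specify_shape` (not part of this diff):

```python
import pytensor.tensor as pt

# x starts with a fully unknown static shape: (None, None)
x = pt.matrix("x")

# None leaves dim 0 unconstrained; the constant 5 is folded into the
# output's static type by SpecifyShape.make_node
y = pt.specify_shape(x, (None, 5))
print(y.type.shape)  # expected: (None, 5)
```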
@@ -460,22 +456,13 @@ def perform(self, node, inp, out_):
     def infer_shape(self, fgraph, node, shapes):
         xshape, *_ = shapes
         shape = node.inputs[1:]
-        new_shape = []
-        for dim in range(node.inputs[0].type.ndim):
-            s = shape[dim]
-            try:
-                s = ptb.get_underlying_scalar_constant_value(s)
-                # We assume that `None` shapes are always retrieved by
-                # `get_underlying_scalar_constant_value`, and only in that case do we default to
-                # the shape of the input variable
-                if s is None:
-                    s = xshape[dim]
-            except NotScalarConstantError:
-                pass
-            new_shape.append(ptb.as_tensor_variable(s))
-
-        assert len(new_shape) == len(xshape)
-        return [new_shape]
+        # Use x shape if specified dim is None, otherwise the specified shape
+        return [
+            [
+                xshape[i] if isinstance(dim.type, NoneTypeT) else dim
+                for i, dim in enumerate(shape)
+            ]
+        ]

     def connection_pattern(self, node):
         return [[True], *[[False]] * len(node.inputs[1:])]
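The simplified `infer_shape` reads straight off the inputs: dimensions given as `None` fall back to the input variable's shape, everything else uses the specified value. A small, untested usage sketch under that assumption (`pytensor.function` and `pt.specify_shape` are standard PyTensor API, not part of this diff):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.specify_shape(x, (None, 5))

# The shape graph should pull dim 0 from x and use the constant 5 for dim 1,
# mirroring the list comprehension in infer_shape above.
f = pytensor.function([x], y.shape)
print(f([[0.0] * 5] * 3))  # expected: [3 5]
```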