1919
2020import pytensor
2121import pytensor .scalar .sharedvar
22- from pytensor import compile , config , printing
22+ from pytensor import config , printing
2323from pytensor import scalar as ps
2424from pytensor .compile .builders import OpFromGraph
2525from pytensor .gradient import DisconnectedType , grad_undefined
3535from pytensor .printing import Printer , min_informative_str , pprint , set_precedence
3636from pytensor .raise_op import CheckAndRaise , assert_op
3737from pytensor .scalar import int32
38- from pytensor .scalar .basic import ScalarConstant , ScalarVariable
38+ from pytensor .scalar .basic import ScalarConstant , ScalarType , ScalarVariable
3939from pytensor .tensor import (
4040 _as_tensor_variable ,
4141 _get_vector_length ,
7171 uint_dtypes ,
7272 values_eq_approx_always_true ,
7373)
74+ from pytensor .tensor .type_other import NoneTypeT
7475from pytensor .tensor .variable import (
7576 TensorConstant ,
7677 TensorVariable ,
77- get_unique_constant_value ,
7878)
7979
8080
@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
319319 but I'm not sure where it is.
320320
321321 """
322+ from pytensor .compile .ops import DeepCopyOp , OutputGuard
323+
322324 v = orig_v
323325 while True :
324326 if v is None :
@@ -336,34 +338,22 @@ def get_underlying_scalar_constant_value(
336338 raise NotScalarConstantError ()
337339
338340 if isinstance (v , Constant ):
339- unique_value = get_unique_constant_value (v )
340- if unique_value is not None :
341- data = unique_value
342- else :
343- data = v .data
344-
345- if isinstance (data , np .ndarray ):
346- try :
347- return np .array (data .item (), dtype = v .dtype )
348- except ValueError :
349- raise NotScalarConstantError ()
341+ if isinstance (v .type , TensorType ) and v .unique_value is not None :
342+ return v .unique_value
350343
351- from pytensor .sparse .type import SparseTensorType
344+ elif isinstance (v .type , ScalarType ):
345+ return v .data
352346
353- if isinstance (v .type , SparseTensorType ):
354- raise NotScalarConstantError ()
347+ elif isinstance (v .type , NoneTypeT ):
348+ return None
355349
356- return data
350+ raise NotScalarConstantError ()
357351
358352 if not only_process_constants and getattr (v , "owner" , None ) and max_recur > 0 :
359353 max_recur -= 1
360354 if isinstance (
361355 v .owner .op ,
362- Alloc
363- | DimShuffle
364- | Unbroadcast
365- | compile .ops .OutputGuard
366- | compile .DeepCopyOp ,
356+ Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ,
367357 ):
368358 # OutputGuard is only used in debugmode but we
369359 # keep it here to avoid problems with old pickles
0 commit comments