@@ -268,27 +268,7 @@ def _obj_is_wrappable_as_tensor(x):
268268)
269269
270270
271- def get_scalar_constant_value (
272- v , elemwise = True , only_process_constants = False , max_recur = 10
273- ):
274- """
275- Checks whether 'v' is a scalar (ndim = 0).
276-
277- If 'v' is a scalar then this function fetches the underlying constant by calling
278- 'get_underlying_scalar_constant_value()'.
279-
280- If 'v' is not a scalar, it raises a NotScalarConstantError.
281-
282- """
283- if isinstance (v , Variable | np .ndarray ):
284- if v .ndim != 0 :
285- raise NotScalarConstantError ()
286- return get_underlying_scalar_constant_value (
287- v , elemwise , only_process_constants , max_recur
288- )
289-
290-
291- def get_underlying_scalar_constant_value (
271+ def _get_underlying_scalar_constant_value (
292272 orig_v , elemwise = True , only_process_constants = False , max_recur = 10
293273):
294274 """Return the constant scalar(0-D) value underlying variable `v`.
@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
381361 elif isinstance (op , CheckAndRaise ):
382362 # check if all conditions are constant and true
383363 conds = [
384- get_underlying_scalar_constant_value (c , max_recur = max_recur )
364+ _get_underlying_scalar_constant_value (c , max_recur = max_recur )
385365 for c in v .owner .inputs [1 :]
386366 ]
387367 if builtins .all (0 == c .ndim and c != 0 for c in conds ):
@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
395375 continue
396376 if isinstance (v .owner .op , _scalar_constant_value_elemwise_ops ):
397377 const = [
398- get_underlying_scalar_constant_value (i , max_recur = max_recur )
378+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
399379 for i in v .owner .inputs
400380 ]
401381 ret = [[None ]]
@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
414394 v .owner .op .scalar_op , _scalar_constant_value_elemwise_ops
415395 ):
416396 const = [
417- get_underlying_scalar_constant_value (i , max_recur = max_recur )
397+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
418398 for i in v .owner .inputs
419399 ]
420400 ret = [[None ]]
@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
457437 ):
458438 idx = v .owner .op .idx_list [0 ]
459439 if isinstance (idx , Type ):
460- idx = get_underlying_scalar_constant_value (
440+ idx = _get_underlying_scalar_constant_value (
461441 v .owner .inputs [1 ], max_recur = max_recur
462442 )
463443 try :
@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
491471 ):
492472 idx = v .owner .op .idx_list [0 ]
493473 if isinstance (idx , Type ):
494- idx = get_underlying_scalar_constant_value (
474+ idx = _get_underlying_scalar_constant_value (
495475 v .owner .inputs [1 ], max_recur = max_recur
496476 )
497- # Python 2.4 does not support indexing with numpy.integer
498- # So we cast it.
499- idx = int (idx )
500477 ret = v .owner .inputs [0 ].owner .inputs [idx ]
501- ret = get_underlying_scalar_constant_value (ret , max_recur = max_recur )
478+ ret = _get_underlying_scalar_constant_value (
479+ ret , max_recur = max_recur
480+ )
502481 # MakeVector can cast implicitly its input in some case.
503482 return np .asarray (ret , dtype = v .type .dtype )
504483
@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
513492 idx_list = op .idx_list
514493 idx = idx_list [0 ]
515494 if isinstance (idx , Type ):
516- idx = get_underlying_scalar_constant_value (
495+ idx = _get_underlying_scalar_constant_value (
517496 owner .inputs [1 ], max_recur = max_recur
518497 )
519498 grandparent = leftmost_parent .owner .inputs [0 ]
@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
523502 grandparent .owner .op , Unbroadcast
524503 ):
525504 ggp_shape = grandparent .owner .inputs [0 ].type .shape
526- l = [get_underlying_scalar_constant_value (s ) for s in ggp_shape ]
505+ l = [
506+ _get_underlying_scalar_constant_value (s ) for s in ggp_shape
507+ ]
527508 gp_shape = tuple (l )
528509
529510 if not (idx < ndim ):
@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
545526 if isinstance (grandparent , Constant ):
546527 return np .asarray (np .shape (grandparent .data )[idx ])
547528 elif isinstance (op , CSM ):
548- data = get_underlying_scalar_constant_value (
529+ data = _get_underlying_scalar_constant_value (
549530 v .owner .inputs , elemwise = elemwise , max_recur = max_recur
550531 )
551532 # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
556537 raise NotScalarConstantError ()
557538
558539
540+ def get_underlying_scalar_constant_value (
541+ v ,
542+ * ,
543+ elemwise = True ,
544+ only_process_constants = False ,
545+ max_recur = 10 ,
546+ raise_not_constant = True ,
547+ ):
548+ """Return the unique constant scalar(0-D) value underlying variable `v`.
549+
550+ If `v` is the output of dimshuffles, fills, allocs, etc,
551+ cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
552+ and some pattern with Subtensor, this function digs through them.
553+
554+ If `v` is not some view of constant scalar data, then raise a
555+ NotScalarConstantError.
556+
557+ This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
558+ constant folding the inputs of `v`.
559+
560+ Parameters
561+ ----------
562+ v: Variable
563+ elemwise : bool
564+ If False, we won't try to go into elemwise. So this call is faster.
565+ But we still investigate in Second Elemwise (as this is a substitute
566+ for Alloc)
567+ only_process_constants : bool
568+ If True, we only attempt to obtain the value of `orig_v` if it's
569+ directly constant and don't try to dig through dimshuffles, fills,
570+ allocs, and other to figure out its value.
571+ max_recur : int
572+ The maximum number of recursion.
573+ raise_not_constant: bool, default True
574+ If True, raise a NotScalarConstantError if `v` does not have an
575+ underlying constant scalar value. If False, return `v` as is.
576+
577+
578+ Raises
579+ ------
580+ NotScalarConstantError
581+ `v` does not have an underlying constant scalar value.
582+ Only rasise if raise_not_constant is True.
583+
584+ """
585+ try :
586+ return _get_underlying_scalar_constant_value (
587+ v ,
588+ elemwise = elemwise ,
589+ only_process_constants = only_process_constants ,
590+ max_recur = max_recur ,
591+ )
592+ except NotScalarConstantError :
593+ if raise_not_constant :
594+ raise
595+ return v
596+
597+
598+ def get_scalar_constant_value (
599+ v ,
600+ elemwise = True ,
601+ only_process_constants = False ,
602+ max_recur = 10 ,
603+ raise_not_constant : bool = True ,
604+ ):
605+ """
606+ Checks whether 'v' is a scalar (ndim = 0).
607+
608+ If 'v' is a scalar then this function fetches the underlying constant by calling
609+ 'get_underlying_scalar_constant_value()'.
610+
611+ If 'v' is not a scalar, it raises a NotScalarConstantError.
612+
613+ """
614+ if isinstance (v , TensorVariable | np .ndarray ):
615+ if v .ndim != 0 :
616+ print (v , v .ndim )
617+ raise NotScalarConstantError ("Input ndim != 0" )
618+ return get_underlying_scalar_constant_value (
619+ v ,
620+ elemwise = elemwise ,
621+ only_process_constants = only_process_constants ,
622+ max_recur = max_recur ,
623+ raise_not_constant = raise_not_constant ,
624+ )
625+
626+
559627class TensorFromScalar (COp ):
560628 __props__ = ()
561629
@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
20122080 ScalarVariable, we convert it to a tensor with tensor_from_scalar.
20132081
20142082 """
2015- try :
2016- x = get_underlying_scalar_constant_value (x , elemwise , only_process_constants )
2017- except NotScalarConstantError :
2018- pass
2019- if isinstance ( x , ps . ScalarVariable | ps . sharedvar . ScalarSharedVariable ):
2020- if x . owner and isinstance ( x . owner . op , ScalarFromTensor ):
2021- x = x . owner . inputs [ 0 ]
2022- else :
2023- x = tensor_from_scalar ( x )
2024- return x
2083+ warnings . warn (
2084+ "extract_constant is deprecated. Use ` get_underlying_scalar_constant_value(..., raise_not_constant=False)`" ,
2085+ FutureWarning ,
2086+ )
2087+ return get_underlying_scalar_constant_value (
2088+ x ,
2089+ elemwise = elemwise ,
2090+ only_process_constants = only_process_constants ,
2091+ raise_not_constant = False ,
2092+ )
20252093
20262094
20272095def transpose (x , axes = None ):
@@ -4401,7 +4469,6 @@ def ix_(*args):
44014469 "split" ,
44024470 "transpose" ,
44034471 "matrix_transpose" ,
4404- "extract_constant" ,
44054472 "default" ,
44064473 "tensor_copy" ,
44074474 "transfer" ,
0 commit comments