@@ -267,27 +267,7 @@ def _obj_is_wrappable_as_tensor(x):
267267)
268268
269269
270- def get_scalar_constant_value (
271- v , elemwise = True , only_process_constants = False , max_recur = 10
272- ):
273- """
274- Checks whether 'v' is a scalar (ndim = 0).
275-
276- If 'v' is a scalar then this function fetches the underlying constant by calling
277- 'get_underlying_scalar_constant_value()'.
278-
279- If 'v' is not a scalar, it raises a NotScalarConstantError.
280-
281- """
282- if isinstance (v , Variable | np .ndarray ):
283- if v .ndim != 0 :
284- raise NotScalarConstantError ()
285- return get_underlying_scalar_constant_value (
286- v , elemwise , only_process_constants , max_recur
287- )
288-
289-
290- def get_underlying_scalar_constant_value (
270+ def _get_underlying_scalar_constant_value (
291271 orig_v , elemwise = True , only_process_constants = False , max_recur = 10
292272):
293273 """Return the constant scalar(0-D) value underlying variable `v`.
@@ -377,7 +357,7 @@ def get_underlying_scalar_constant_value(
377357 elif isinstance (op , CheckAndRaise ):
378358 # check if all conditions are constant and true
379359 conds = [
380- get_underlying_scalar_constant_value (c , max_recur = max_recur )
360+ _get_underlying_scalar_constant_value (c , max_recur = max_recur )
381361 for c in v .owner .inputs [1 :]
382362 ]
383363 if builtins .all (0 == c .ndim and c != 0 for c in conds ):
@@ -391,7 +371,7 @@ def get_underlying_scalar_constant_value(
391371 continue
392372 if isinstance (v .owner .op , _scalar_constant_value_elemwise_ops ):
393373 const = [
394- get_underlying_scalar_constant_value (i , max_recur = max_recur )
374+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
395375 for i in v .owner .inputs
396376 ]
397377 ret = [[None ]]
@@ -410,7 +390,7 @@ def get_underlying_scalar_constant_value(
410390 v .owner .op .scalar_op , _scalar_constant_value_elemwise_ops
411391 ):
412392 const = [
413- get_underlying_scalar_constant_value (i , max_recur = max_recur )
393+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
414394 for i in v .owner .inputs
415395 ]
416396 ret = [[None ]]
@@ -453,7 +433,7 @@ def get_underlying_scalar_constant_value(
453433 ):
454434 idx = v .owner .op .idx_list [0 ]
455435 if isinstance (idx , Type ):
456- idx = get_underlying_scalar_constant_value (
436+ idx = _get_underlying_scalar_constant_value (
457437 v .owner .inputs [1 ], max_recur = max_recur
458438 )
459439 try :
@@ -487,14 +467,13 @@ def get_underlying_scalar_constant_value(
487467 ):
488468 idx = v .owner .op .idx_list [0 ]
489469 if isinstance (idx , Type ):
490- idx = get_underlying_scalar_constant_value (
470+ idx = _get_underlying_scalar_constant_value (
491471 v .owner .inputs [1 ], max_recur = max_recur
492472 )
493- # Python 2.4 does not support indexing with numpy.integer
494- # So we cast it.
495- idx = int (idx )
496473 ret = v .owner .inputs [0 ].owner .inputs [idx ]
497- ret = get_underlying_scalar_constant_value (ret , max_recur = max_recur )
474+ ret = _get_underlying_scalar_constant_value (
475+ ret , max_recur = max_recur
476+ )
498477 # MakeVector can cast implicitly its input in some case.
499478 return np .asarray (ret , dtype = v .type .dtype )
500479
@@ -509,7 +488,7 @@ def get_underlying_scalar_constant_value(
509488 idx_list = op .idx_list
510489 idx = idx_list [0 ]
511490 if isinstance (idx , Type ):
512- idx = get_underlying_scalar_constant_value (
491+ idx = _get_underlying_scalar_constant_value (
513492 owner .inputs [1 ], max_recur = max_recur
514493 )
515494 grandparent = leftmost_parent .owner .inputs [0 ]
@@ -519,7 +498,9 @@ def get_underlying_scalar_constant_value(
519498 grandparent .owner .op , Unbroadcast
520499 ):
521500 ggp_shape = grandparent .owner .inputs [0 ].type .shape
522- l = [get_underlying_scalar_constant_value (s ) for s in ggp_shape ]
501+ l = [
502+ _get_underlying_scalar_constant_value (s ) for s in ggp_shape
503+ ]
523504 gp_shape = tuple (l )
524505
525506 if not (idx < ndim ):
@@ -541,7 +522,7 @@ def get_underlying_scalar_constant_value(
541522 if isinstance (grandparent , Constant ):
542523 return np .asarray (np .shape (grandparent .data )[idx ])
543524 elif isinstance (op , CSM ):
544- data = get_underlying_scalar_constant_value (
525+ data = _get_underlying_scalar_constant_value (
545526 v .owner .inputs , elemwise = elemwise , max_recur = max_recur
546527 )
547528 # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
@@ -552,6 +533,92 @@ def get_underlying_scalar_constant_value(
552533 raise NotScalarConstantError ()
553534
554535
536+ def get_underlying_scalar_constant_value (
537+ v ,
538+ * ,
539+ elemwise = True ,
540+ only_process_constants = False ,
541+ max_recur = 10 ,
542+ raise_not_constant = True ,
543+ ):
544+ """Return the unique constant scalar(0-D) value underlying variable `v`.
545+
546+ If `v` is the output of dimshuffles, fills, allocs, etc,
547+ cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
548+ and some pattern with Subtensor, this function digs through them.
549+
550+ If `v` is not some view of constant scalar data, then raise a
551+ NotScalarConstantError.
552+
553+ This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
554+ constant folding the inputs of `v`.
555+
556+ Parameters
557+ ----------
558+ v: Variable
559+ elemwise : bool
560+ If False, we won't try to go into elemwise. So this call is faster.
561+ But we still investigate in Second Elemwise (as this is a substitute
562+ for Alloc)
563+ only_process_constants : bool
564+ If True, we only attempt to obtain the value of `orig_v` if it's
565+ directly constant and don't try to dig through dimshuffles, fills,
566+ allocs, and other to figure out its value.
567+ max_recur : int
568+ The maximum number of recursion.
569+ raise_not_constant: bool, default True
570+ If True, raise a NotScalarConstantError if `v` does not have an
571+ underlying constant scalar value. If False, return `v` as is.
572+
573+
574+ Raises
575+ ------
576+ NotScalarConstantError
577+ `v` does not have an underlying constant scalar value.
578+ Only rasise if raise_not_constant is True.
579+
580+ """
581+ try :
582+ return _get_underlying_scalar_constant_value (
583+ v ,
584+ elemwise = elemwise ,
585+ only_process_constants = only_process_constants ,
586+ max_recur = max_recur ,
587+ )
588+ except NotScalarConstantError :
589+ if raise_not_constant :
590+ raise
591+ return v
592+
593+
594+ def get_scalar_constant_value (
595+ v ,
596+ elemwise = True ,
597+ only_process_constants = False ,
598+ max_recur = 10 ,
599+ raise_not_constant : bool = True ,
600+ ):
601+ """
602+ Checks whether 'v' is a scalar (ndim = 0).
603+
604+ If 'v' is a scalar then this function fetches the underlying constant by calling
605+ 'get_underlying_scalar_constant_value()'.
606+
607+ If 'v' is not a scalar, it raises a TypeError.
608+
609+ """
610+ if isinstance (v , Variable | np .ndarray ):
611+ if v .ndim != 0 :
612+ raise TypeError ()
613+ return get_underlying_scalar_constant_value (
614+ v ,
615+ elemwise = elemwise ,
616+ only_process_constants = only_process_constants ,
617+ max_recur = max_recur ,
618+ raise_not_constant = raise_not_constant ,
619+ )
620+
621+
555622class TensorFromScalar (COp ):
556623 __props__ = ()
557624
@@ -2006,16 +2073,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
20062073 ScalarVariable, we convert it to a tensor with tensor_from_scalar.
20072074
20082075 """
2009- try :
2010- x = get_underlying_scalar_constant_value (x , elemwise , only_process_constants )
2011- except NotScalarConstantError :
2012- pass
2013- if isinstance ( x , ps . ScalarVariable | ps . sharedvar . ScalarSharedVariable ):
2014- if x . owner and isinstance ( x . owner . op , ScalarFromTensor ):
2015- x = x . owner . inputs [ 0 ]
2016- else :
2017- x = tensor_from_scalar ( x )
2018- return x
2076+ warnings . warn (
2077+ "extract_constant is deprecated. Use ` get_underlying_scalar_constant_value(..., raise_not_constant=False)`" ,
2078+ FutureWarning ,
2079+ )
2080+ return get_underlying_scalar_constant_value (
2081+ x ,
2082+ elemwise = elemwise ,
2083+ only_process_constants = only_process_constants ,
2084+ raise_not_constant = False ,
2085+ )
20192086
20202087
20212088def transpose (x , axes = None ):
@@ -4394,7 +4461,6 @@ def ix_(*args):
43944461 "split" ,
43954462 "transpose" ,
43964463 "matrix_transpose" ,
4397- "extract_constant" ,
43984464 "default" ,
43994465 "tensor_copy" ,
44004466 "transfer" ,
0 commit comments