@@ -268,27 +268,7 @@ def _obj_is_wrappable_as_tensor(x):
268
268
)
269
269
270
270
271
- def get_scalar_constant_value (
272
- v , elemwise = True , only_process_constants = False , max_recur = 10
273
- ):
274
- """
275
- Checks whether 'v' is a scalar (ndim = 0).
276
-
277
- If 'v' is a scalar then this function fetches the underlying constant by calling
278
- 'get_underlying_scalar_constant_value()'.
279
-
280
- If 'v' is not a scalar, it raises a NotScalarConstantError.
281
-
282
- """
283
- if isinstance (v , Variable | np .ndarray ):
284
- if v .ndim != 0 :
285
- raise NotScalarConstantError ()
286
- return get_underlying_scalar_constant_value (
287
- v , elemwise , only_process_constants , max_recur
288
- )
289
-
290
-
291
- def get_underlying_scalar_constant_value (
271
+ def _get_underlying_scalar_constant_value (
292
272
orig_v , elemwise = True , only_process_constants = False , max_recur = 10
293
273
):
294
274
"""Return the constant scalar(0-D) value underlying variable `v`.
@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
381
361
elif isinstance (op , CheckAndRaise ):
382
362
# check if all conditions are constant and true
383
363
conds = [
384
- get_underlying_scalar_constant_value (c , max_recur = max_recur )
364
+ _get_underlying_scalar_constant_value (c , max_recur = max_recur )
385
365
for c in v .owner .inputs [1 :]
386
366
]
387
367
if builtins .all (0 == c .ndim and c != 0 for c in conds ):
@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
395
375
continue
396
376
if isinstance (v .owner .op , _scalar_constant_value_elemwise_ops ):
397
377
const = [
398
- get_underlying_scalar_constant_value (i , max_recur = max_recur )
378
+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
399
379
for i in v .owner .inputs
400
380
]
401
381
ret = [[None ]]
@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
414
394
v .owner .op .scalar_op , _scalar_constant_value_elemwise_ops
415
395
):
416
396
const = [
417
- get_underlying_scalar_constant_value (i , max_recur = max_recur )
397
+ _get_underlying_scalar_constant_value (i , max_recur = max_recur )
418
398
for i in v .owner .inputs
419
399
]
420
400
ret = [[None ]]
@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
457
437
):
458
438
idx = v .owner .op .idx_list [0 ]
459
439
if isinstance (idx , Type ):
460
- idx = get_underlying_scalar_constant_value (
440
+ idx = _get_underlying_scalar_constant_value (
461
441
v .owner .inputs [1 ], max_recur = max_recur
462
442
)
463
443
try :
@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
491
471
):
492
472
idx = v .owner .op .idx_list [0 ]
493
473
if isinstance (idx , Type ):
494
- idx = get_underlying_scalar_constant_value (
474
+ idx = _get_underlying_scalar_constant_value (
495
475
v .owner .inputs [1 ], max_recur = max_recur
496
476
)
497
- # Python 2.4 does not support indexing with numpy.integer
498
- # So we cast it.
499
- idx = int (idx )
500
477
ret = v .owner .inputs [0 ].owner .inputs [idx ]
501
- ret = get_underlying_scalar_constant_value (ret , max_recur = max_recur )
478
+ ret = _get_underlying_scalar_constant_value (
479
+ ret , max_recur = max_recur
480
+ )
502
481
# MakeVector can cast implicitly its input in some case.
503
482
return np .asarray (ret , dtype = v .type .dtype )
504
483
@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
513
492
idx_list = op .idx_list
514
493
idx = idx_list [0 ]
515
494
if isinstance (idx , Type ):
516
- idx = get_underlying_scalar_constant_value (
495
+ idx = _get_underlying_scalar_constant_value (
517
496
owner .inputs [1 ], max_recur = max_recur
518
497
)
519
498
grandparent = leftmost_parent .owner .inputs [0 ]
@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
523
502
grandparent .owner .op , Unbroadcast
524
503
):
525
504
ggp_shape = grandparent .owner .inputs [0 ].type .shape
526
- l = [get_underlying_scalar_constant_value (s ) for s in ggp_shape ]
505
+ l = [
506
+ _get_underlying_scalar_constant_value (s ) for s in ggp_shape
507
+ ]
527
508
gp_shape = tuple (l )
528
509
529
510
if not (idx < ndim ):
@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
545
526
if isinstance (grandparent , Constant ):
546
527
return np .asarray (np .shape (grandparent .data )[idx ])
547
528
elif isinstance (op , CSM ):
548
- data = get_underlying_scalar_constant_value (
529
+ data = _get_underlying_scalar_constant_value (
549
530
v .owner .inputs , elemwise = elemwise , max_recur = max_recur
550
531
)
551
532
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
556
537
raise NotScalarConstantError ()
557
538
558
539
540
+ def get_underlying_scalar_constant_value (
541
+ v ,
542
+ * ,
543
+ elemwise = True ,
544
+ only_process_constants = False ,
545
+ max_recur = 10 ,
546
+ raise_not_constant = True ,
547
+ ):
548
+ """Return the unique constant scalar(0-D) value underlying variable `v`.
549
+
550
+ If `v` is the output of dimshuffles, fills, allocs, etc,
551
+ cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
552
+ and some pattern with Subtensor, this function digs through them.
553
+
554
+ If `v` is not some view of constant scalar data, then raise a
555
+ NotScalarConstantError.
556
+
557
+ This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by
558
+ constant folding the inputs of `v`.
559
+
560
+ Parameters
561
+ ----------
562
+ v: Variable
563
+ elemwise : bool
564
+ If False, we won't try to go into elemwise. So this call is faster.
565
+ But we still investigate in Second Elemwise (as this is a substitute
566
+ for Alloc)
567
+ only_process_constants : bool
568
+ If True, we only attempt to obtain the value of `orig_v` if it's
569
+ directly constant and don't try to dig through dimshuffles, fills,
570
+ allocs, and other to figure out its value.
571
+ max_recur : int
572
+ The maximum number of recursion.
573
+ raise_not_constant: bool, default True
574
+ If True, raise a NotScalarConstantError if `v` does not have an
575
+ underlying constant scalar value. If False, return `v` as is.
576
+
577
+
578
+ Raises
579
+ ------
580
+ NotScalarConstantError
581
+ `v` does not have an underlying constant scalar value.
582
+ Only rasise if raise_not_constant is True.
583
+
584
+ """
585
+ try :
586
+ return _get_underlying_scalar_constant_value (
587
+ v ,
588
+ elemwise = elemwise ,
589
+ only_process_constants = only_process_constants ,
590
+ max_recur = max_recur ,
591
+ )
592
+ except NotScalarConstantError :
593
+ if raise_not_constant :
594
+ raise
595
+ return v
596
+
597
+
598
+ def get_scalar_constant_value (
599
+ v ,
600
+ elemwise = True ,
601
+ only_process_constants = False ,
602
+ max_recur = 10 ,
603
+ raise_not_constant : bool = True ,
604
+ ):
605
+ """
606
+ Checks whether 'v' is a scalar (ndim = 0).
607
+
608
+ If 'v' is a scalar then this function fetches the underlying constant by calling
609
+ 'get_underlying_scalar_constant_value()'.
610
+
611
+ If 'v' is not a scalar, it raises a NotScalarConstantError.
612
+
613
+ """
614
+ if isinstance (v , TensorVariable | np .ndarray ):
615
+ if v .ndim != 0 :
616
+ print (v , v .ndim )
617
+ raise NotScalarConstantError ("Input ndim != 0" )
618
+ return get_underlying_scalar_constant_value (
619
+ v ,
620
+ elemwise = elemwise ,
621
+ only_process_constants = only_process_constants ,
622
+ max_recur = max_recur ,
623
+ raise_not_constant = raise_not_constant ,
624
+ )
625
+
626
+
559
627
class TensorFromScalar (COp ):
560
628
__props__ = ()
561
629
@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
2012
2080
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
2013
2081
2014
2082
"""
2015
- try :
2016
- x = get_underlying_scalar_constant_value (x , elemwise , only_process_constants )
2017
- except NotScalarConstantError :
2018
- pass
2019
- if isinstance ( x , ps . ScalarVariable | ps . sharedvar . ScalarSharedVariable ):
2020
- if x . owner and isinstance ( x . owner . op , ScalarFromTensor ):
2021
- x = x . owner . inputs [ 0 ]
2022
- else :
2023
- x = tensor_from_scalar ( x )
2024
- return x
2083
+ warnings . warn (
2084
+ "extract_constant is deprecated. Use ` get_underlying_scalar_constant_value(..., raise_not_constant=False)`" ,
2085
+ FutureWarning ,
2086
+ )
2087
+ return get_underlying_scalar_constant_value (
2088
+ x ,
2089
+ elemwise = elemwise ,
2090
+ only_process_constants = only_process_constants ,
2091
+ raise_not_constant = False ,
2092
+ )
2025
2093
2026
2094
2027
2095
def transpose (x , axes = None ):
@@ -4401,7 +4469,6 @@ def ix_(*args):
4401
4469
"split" ,
4402
4470
"transpose" ,
4403
4471
"matrix_transpose" ,
4404
- "extract_constant" ,
4405
4472
"default" ,
4406
4473
"tensor_copy" ,
4407
4474
"transfer" ,
0 commit comments