@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
320
320
321
321
"""
322
322
from pytensor .compile .ops import DeepCopyOp , OutputGuard
323
+ from pytensor .sparse import CSM
324
+ from pytensor .tensor .subtensor import Subtensor
323
325
324
326
v = orig_v
325
327
while True :
@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
350
352
raise NotScalarConstantError ()
351
353
352
354
if not only_process_constants and getattr (v , "owner" , None ) and max_recur > 0 :
355
+ op = v .owner .op
353
356
max_recur -= 1
354
357
if isinstance (
355
- v .owner .op ,
356
- Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ,
358
+ op , Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
357
359
):
358
360
# OutputGuard is only used in debugmode but we
359
361
# keep it here to avoid problems with old pickles
360
362
v = v .owner .inputs [0 ]
361
363
continue
362
- elif isinstance (v . owner . op , Shape_i ):
364
+ elif isinstance (op , Shape_i ):
363
365
i = v .owner .op .i
364
366
inp = v .owner .inputs [0 ]
365
367
if isinstance (inp , Constant ):
@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
373
375
# mess with the stabilization optimization and be too slow.
374
376
# We put all the scalar Ops used by get_canonical_form_slice()
375
377
# to allow it to determine the broadcast pattern correctly.
376
- elif isinstance (v . owner . op , ScalarFromTensor | TensorFromScalar ):
378
+ elif isinstance (op , ScalarFromTensor | TensorFromScalar ):
377
379
v = v .owner .inputs [0 ]
378
380
continue
379
- elif isinstance (v . owner . op , CheckAndRaise ):
381
+ elif isinstance (op , CheckAndRaise ):
380
382
# check if all conditions are constant and true
381
383
conds = [
382
384
get_underlying_scalar_constant_value (c , max_recur = max_recur )
@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
385
387
if builtins .all (0 == c .ndim and c != 0 for c in conds ):
386
388
v = v .owner .inputs [0 ]
387
389
continue
388
- elif isinstance (v . owner . op , ps .ScalarOp ):
390
+ elif isinstance (op , ps .ScalarOp ):
389
391
if isinstance (v .owner .op , ps .Second ):
390
392
# We don't need both input to be constant for second
391
393
shp , val = v .owner .inputs
@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
402
404
# In fast_compile, we don't enable local_fill_to_alloc, so
403
405
# we need to investigate Second as Alloc. So elemwise
404
406
# don't disable the check for Second.
405
- elif isinstance (v . owner . op , Elemwise ):
407
+ elif isinstance (op , Elemwise ):
406
408
if isinstance (v .owner .op .scalar_op , ps .Second ):
407
409
# We don't need both input to be constant for second
408
410
shp , val = v .owner .inputs
@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
418
420
ret = [[None ]]
419
421
v .owner .op .perform (v .owner , const , ret )
420
422
return np .asarray (ret [0 ][0 ].copy ())
421
- elif (
422
- isinstance (v .owner .op , pytensor .tensor .subtensor .Subtensor )
423
- and v .ndim == 0
424
- ):
423
+ elif isinstance (op , Subtensor ) and v .ndim == 0 :
425
424
if isinstance (v .owner .inputs [0 ], TensorConstant ):
426
425
from pytensor .tensor .subtensor import get_constant_idx
427
426
@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
545
544
546
545
if isinstance (grandparent , Constant ):
547
546
return np .asarray (np .shape (grandparent .data )[idx ])
547
+ elif isinstance (op , CSM ):
548
+ data = get_underlying_scalar_constant_value (
549
+ v .owner .inputs , elemwise = elemwise , max_recur = max_recur
550
+ )
551
+ # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
552
+ if data == 0 :
553
+ return data
554
+ break
548
555
549
556
raise NotScalarConstantError ()
550
557
@@ -4071,7 +4078,7 @@ def make_node(self, a, choices):
4071
4078
static_out_shape = ()
4072
4079
for s in out_shape :
4073
4080
try :
4074
- s_val = pytensor . get_underlying_scalar_constant (s )
4081
+ s_val = get_underlying_scalar_constant_value (s )
4075
4082
except (NotScalarConstantError , AttributeError ):
4076
4083
s_val = None
4077
4084
0 commit comments