@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
320320
321321 """
322322 from pytensor .compile .ops import DeepCopyOp , OutputGuard
323+ from pytensor .sparse import CSM
324+ from pytensor .tensor .subtensor import Subtensor
323325
324326 v = orig_v
325327 while True :
@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
350352 raise NotScalarConstantError ()
351353
352354 if not only_process_constants and getattr (v , "owner" , None ) and max_recur > 0 :
355+ op = v .owner .op
353356 max_recur -= 1
354357 if isinstance (
355- v .owner .op ,
356- Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ,
358+ op , Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
357359 ):
358360 # OutputGuard is only used in debugmode but we
359361 # keep it here to avoid problems with old pickles
360362 v = v .owner .inputs [0 ]
361363 continue
362- elif isinstance (v . owner . op , Shape_i ):
364+ elif isinstance (op , Shape_i ):
363365 i = v .owner .op .i
364366 inp = v .owner .inputs [0 ]
365367 if isinstance (inp , Constant ):
@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
373375 # mess with the stabilization optimization and be too slow.
374376 # We put all the scalar Ops used by get_canonical_form_slice()
375377 # to allow it to determine the broadcast pattern correctly.
376- elif isinstance (v . owner . op , ScalarFromTensor | TensorFromScalar ):
378+ elif isinstance (op , ScalarFromTensor | TensorFromScalar ):
377379 v = v .owner .inputs [0 ]
378380 continue
379- elif isinstance (v . owner . op , CheckAndRaise ):
381+ elif isinstance (op , CheckAndRaise ):
380382 # check if all conditions are constant and true
381383 conds = [
382384 get_underlying_scalar_constant_value (c , max_recur = max_recur )
@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
385387 if builtins .all (0 == c .ndim and c != 0 for c in conds ):
386388 v = v .owner .inputs [0 ]
387389 continue
388- elif isinstance (v . owner . op , ps .ScalarOp ):
390+ elif isinstance (op , ps .ScalarOp ):
389391 if isinstance (v .owner .op , ps .Second ):
390392 # We don't need both input to be constant for second
391393 shp , val = v .owner .inputs
@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
402404 # In fast_compile, we don't enable local_fill_to_alloc, so
403405 # we need to investigate Second as Alloc. So elemwise
404406 # don't disable the check for Second.
405- elif isinstance (v . owner . op , Elemwise ):
407+ elif isinstance (op , Elemwise ):
406408 if isinstance (v .owner .op .scalar_op , ps .Second ):
407409 # We don't need both input to be constant for second
408410 shp , val = v .owner .inputs
@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
418420 ret = [[None ]]
419421 v .owner .op .perform (v .owner , const , ret )
420422 return np .asarray (ret [0 ][0 ].copy ())
421- elif (
422- isinstance (v .owner .op , pytensor .tensor .subtensor .Subtensor )
423- and v .ndim == 0
424- ):
423+ elif isinstance (op , Subtensor ) and v .ndim == 0 :
425424 if isinstance (v .owner .inputs [0 ], TensorConstant ):
426425 from pytensor .tensor .subtensor import get_constant_idx
427426
@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
545544
546545 if isinstance (grandparent , Constant ):
547546 return np .asarray (np .shape (grandparent .data )[idx ])
547+ elif isinstance (op , CSM ):
548+ data = get_underlying_scalar_constant_value (
549+ v .owner .inputs , elemwise = elemwise , max_recur = max_recur
550+ )
551+ # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
552+ if data == 0 :
553+ return data
554+ break
548555
549556 raise NotScalarConstantError ()
550557
@@ -4071,7 +4078,7 @@ def make_node(self, a, choices):
40714078 static_out_shape = ()
40724079 for s in out_shape :
40734080 try :
4074- s_val = pytensor . get_underlying_scalar_constant (s )
4081+ s_val = get_underlying_scalar_constant_value (s )
40754082 except (NotScalarConstantError , AttributeError ):
40764083 s_val = None
40774084
0 commit comments