@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
319319
320320 """
321321 from pytensor .compile .ops import DeepCopyOp , OutputGuard
322+ from pytensor .sparse import CSM
323+ from pytensor .tensor .subtensor import Subtensor
322324
323325 v = orig_v
324326 while True :
@@ -346,16 +348,16 @@ def get_underlying_scalar_constant_value(
346348 raise NotScalarConstantError ()
347349
348350 if not only_process_constants and getattr (v , "owner" , None ) and max_recur > 0 :
351+ op = v .owner .op
349352 max_recur -= 1
350353 if isinstance (
351- v .owner .op ,
352- Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ,
354+ op , Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
353355 ):
354356 # OutputGuard is only used in debugmode but we
355357 # keep it here to avoid problems with old pickles
356358 v = v .owner .inputs [0 ]
357359 continue
358- elif isinstance (v . owner . op , Shape_i ):
360+ elif isinstance (op , Shape_i ):
359361 i = v .owner .op .i
360362 inp = v .owner .inputs [0 ]
361363 if isinstance (inp , Constant ):
@@ -369,10 +371,10 @@ def get_underlying_scalar_constant_value(
369371 # mess with the stabilization optimization and be too slow.
370372 # We put all the scalar Ops used by get_canonical_form_slice()
371373 # to allow it to determine the broadcast pattern correctly.
372- elif isinstance (v . owner . op , ScalarFromTensor | TensorFromScalar ):
374+ elif isinstance (op , ScalarFromTensor | TensorFromScalar ):
373375 v = v .owner .inputs [0 ]
374376 continue
375- elif isinstance (v . owner . op , CheckAndRaise ):
377+ elif isinstance (op , CheckAndRaise ):
376378 # check if all conditions are constant and true
377379 conds = [
378380 get_underlying_scalar_constant_value (c , max_recur = max_recur )
@@ -381,7 +383,7 @@ def get_underlying_scalar_constant_value(
381383 if builtins .all (0 == c .ndim and c != 0 for c in conds ):
382384 v = v .owner .inputs [0 ]
383385 continue
384- elif isinstance (v . owner . op , ps .ScalarOp ):
386+ elif isinstance (op , ps .ScalarOp ):
385387 if isinstance (v .owner .op , ps .Second ):
386388 # We don't need both input to be constant for second
387389 shp , val = v .owner .inputs
@@ -398,7 +400,7 @@ def get_underlying_scalar_constant_value(
398400 # In fast_compile, we don't enable local_fill_to_alloc, so
399401 # we need to investigate Second as Alloc. So elemwise
400402 # don't disable the check for Second.
401- elif isinstance (v . owner . op , Elemwise ):
403+ elif isinstance (op , Elemwise ):
402404 if isinstance (v .owner .op .scalar_op , ps .Second ):
403405 # We don't need both input to be constant for second
404406 shp , val = v .owner .inputs
@@ -414,10 +416,7 @@ def get_underlying_scalar_constant_value(
414416 ret = [[None ]]
415417 v .owner .op .perform (v .owner , const , ret )
416418 return np .asarray (ret [0 ][0 ].copy ())
417- elif (
418- isinstance (v .owner .op , pytensor .tensor .subtensor .Subtensor )
419- and v .ndim == 0
420- ):
419+ elif isinstance (op , Subtensor ) and v .ndim == 0 :
421420 if isinstance (v .owner .inputs [0 ], TensorConstant ):
422421 from pytensor .tensor .subtensor import get_constant_idx
423422
@@ -541,6 +540,14 @@ def get_underlying_scalar_constant_value(
541540
542541 if isinstance (grandparent , Constant ):
543542 return np .asarray (np .shape (grandparent .data )[idx ])
543+ elif isinstance (op , CSM ):
544+ data = get_underlying_scalar_constant_value (
545+ v .owner .inputs , elemwise = elemwise , max_recur = max_recur
546+ )
547+ # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
548+ if data == 0 :
549+ return data
550+ break
544551
545552 raise NotScalarConstantError ()
546553
@@ -4064,7 +4071,7 @@ def make_node(self, a, choices):
40644071 static_out_shape = ()
40654072 for s in out_shape :
40664073 try :
4067- s_val = pytensor . get_underlying_scalar_constant (s )
4074+ s_val = get_underlying_scalar_constant_value (s )
40684075 except (NotScalarConstantError , AttributeError ):
40694076 s_val = None
40704077
0 commit comments