@@ -400,6 +400,7 @@ def smart_subs_dict(
400400 subs : SymbolDef ,
401401 field : str | None = None ,
402402 reverse : bool = True ,
403+ flatten_first : bool | None = None ,
403404) -> sp .Expr :
404405 """
405406 Substitutes expressions completely flattening them out. Requires
@@ -418,6 +419,11 @@ def smart_subs_dict(
418419 Whether ordering in subs should be reversed. Note that substitution
419420 requires the reverse order of what is required for evaluation.
420421
422+ :param flatten_first:
423+ Choice of algorithm: Flatten the substitution expressions first, then
424+ substitute them simultaneously into `sym` (``True``), or substitute
425+ them one by one into `sym` (``False``).
426+
421427 :return:
422428 Substituted symbolic expression
423429 """
@@ -426,27 +432,61 @@ def smart_subs_dict(
426432 else :
427433 s = [(eid , expr [field ]) for eid , expr in subs .items ()]
428434
429- if not reverse :
430- # counter-intuitive, but we need to reverse the order for reverse=False
431- s .reverse ()
432-
433- with sp .evaluate (False ):
434- # The new expressions may themselves contain symbols to be substituted.
435- # We flatten them out first, so that the substitutions in `sym` can be
436- # performed simultaneously, which is usually more efficient than
437- # repeatedly substituting into `sym`.
438- # TODO(performance): This could probably be made more efficient by
439- # combining with toposort used to order `subs` in the first place.
440- # Some substitutions could be combined, and some terms not present in
441- # `sym` could be skipped.
442- for i in range (len (s ) - 1 ):
443- for j in range (i + 1 , len (s )):
444- if s [j ][1 ].has (s [i ][0 ]):
445- s [j ] = s [j ][0 ], s [j ][1 ].xreplace ({s [i ][0 ]: s [i ][1 ]})
446-
447- s = dict (s )
448- sym = sym .xreplace (s )
449- return sym
435+ # We have the choice to flatten the replacement expressions first or to
436+ # substitute them one by one into `sym`. Flattening first is usually
437+ # more efficient if `sym` is large (e.g., a matrix with many elements)
438+ # and `subs` is cascading (i.e., substitutions depend on other
439+ # substitutions). Otherwise, substituting one by one is usually more
440+ # efficient, because flattening scales quadratically with the number of
441+ # substitutions.
442+ # The exact threshold is somewhat arbitrary and may need to be
443+ # adjusted in the future.
444+ if flatten_first is None :
445+ flatten_first = (
446+ isinstance (sym , sp .MatrixBase ) and sym .rows * sym .cols > 20
447+ )
448+
449+ if flatten_first :
450+ if not reverse :
451+ # counter-intuitive, but on this branch, we need to reverse the
452+ # order for `reverse=False`
453+ s .reverse ()
454+
455+ with sp .evaluate (False ):
456+ # The new expressions may themselves contain symbols to be
457+ # substituted. We flatten them out first, so that the
458+ # substitutions in `sym` can be performed simultaneously,
459+ # which can be more efficient than repeatedly substituting into
460+ # `sym`.
461+ # TODO(performance): This could probably be made more efficient by
462+ # combining with toposort used to order `subs` in the first
463+ # place.
464+ # Some substitutions could be combined, and some terms not
465+ # present in `sym` could be skipped.
466+ # Furthermore, this would provide information on recursion depth,
467+ # which might help decide which strategy is more efficient.
468+ # For flat hierarchies, substituting one by one is most likely
469+ # more efficient.
470+ for i in range (len (s ) - 1 ):
471+ for j in range (i + 1 , len (s )):
472+ if s [j ][1 ].has (s [i ][0 ]):
473+ s [j ] = s [j ][0 ], s [j ][1 ].xreplace ({s [i ][0 ]: s [i ][1 ]})
474+
475+ s = dict (s )
476+ sym = sym .xreplace (s )
477+ return sym
478+
479+ else :
480+ if reverse :
481+ s .reverse ()
482+
483+ with sp .evaluate (False ):
484+ for old , new in s :
485+ # note that substitution may change free symbols,
486+ # so we have to do this recursively
487+ if sym .has (old ):
488+ sym = sym .xreplace ({old : new })
489+ return sym
450490
451491
452492def smart_subs (element : sp .Expr , old : sp .Symbol , new : sp .Expr ) -> sp .Expr :
0 commit comments