@@ -426,27 +426,58 @@ def smart_subs_dict(
426426 else :
427427 s = [(eid , expr [field ]) for eid , expr in subs .items ()]
428428
429- if not reverse :
430- # counter-intuitive, but we need to reverse the order for reverse=False
431- s .reverse ()
432-
433- with sp .evaluate (False ):
434- # The new expressions may themselves contain symbols to be substituted.
435- # We flatten them out first, so that the substitutions in `sym` can be
436- # performed simultaneously, which is usually more efficient than
437- # repeatedly substituting into `sym`.
438- # TODO(performance): This could probably be made more efficient by
439- # combining with toposort used to order `subs` in the first place.
440- # Some substitutions could be combined, and some terms not present in
441- # `sym` could be skipped.
442- for i in range (len (s ) - 1 ):
443- for j in range (i + 1 , len (s )):
444- if s [j ][1 ].has (s [i ][0 ]):
445- s [j ] = s [j ][0 ], s [j ][1 ].xreplace ({s [i ][0 ]: s [i ][1 ]})
446-
447- s = dict (s )
448- sym = sym .xreplace (s )
449- return sym
429+ # We have the choice to flatten the replacement expressions first or to
430+ # substitute them one by one into `sym`. Flattening first is usually
431+ # more efficient if `sym` is large (e.g., a matrix with many elements)
432+ # and `subs` is cascading (i.e., substitutions depend on other
433+ # substitutions). Otherwise, substituting one by one is usually more
434+ # efficient, because flattening scales quadratically with the number of
435+ # substitutions.
436+ # The exact threshold is somewhat arbitrary and may need to be
437+ # adjusted in the future.
438+ flatten_first = isinstance (sym , sp .MatrixBase ) and sym .rows * sym .cols > 20
439+
440+ if flatten_first :
441+ if not reverse :
442+ # counter-intuitive, but on this branch, we need to reverse the
443+ # order for `reverse=False`
444+ s .reverse ()
445+
446+ with sp .evaluate (False ):
447+ # The new expressions may themselves contain symbols to be
448+ # substituted. We flatten them out first, so that the
449+ # substitutions in `sym` can be performed simultaneously,
450+ # which can be more efficient than repeatedly substituting into
451+ # `sym`.
452+ # TODO(performance): This could probably be made more efficient by
453+ # combining with toposort used to order `subs` in the first
454+ # place.
455+ # Some substitutions could be combined, and some terms not
456+ # present in `sym` could be skipped.
457+ # Furthermore, this would provide information on recursion depth,
458+ # which might help decide which strategy is more efficient.
459+ # For flat hierarchies, substituting one by one is most likely
460+ # more efficient.
461+ for i in range (len (s ) - 1 ):
462+ for j in range (i + 1 , len (s )):
463+ if s [j ][1 ].has (s [i ][0 ]):
464+ s [j ] = s [j ][0 ], s [j ][1 ].xreplace ({s [i ][0 ]: s [i ][1 ]})
465+
466+ s = dict (s )
467+ sym = sym .xreplace (s )
468+ return sym
469+
470+ else :
471+ if reverse :
472+ s .reverse ()
473+
474+ with sp .evaluate (False ):
475+ for old , new in s :
476+ # note that substitution may change free symbols,
477+ # so we have to do this recursively
478+ if sym .has (old ):
479+ sym = sym .xreplace ({old : new })
480+ return sym
450481
451482
452483def smart_subs (element : sp .Expr , old : sp .Symbol , new : sp .Expr ) -> sp .Expr :
0 commit comments