@@ -32,14 +32,16 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L}
3232 path = sumtraces (path)
3333 metric = config. metric (sizedict)
3434
35+ hashyperinds = ! isempty (hyperinds (path))
36+
3537 # generate initial candidate contractions
3638 queue = MutableBinaryHeap {Tuple{Float64,EinExpr{L}}} (
3739 Base. By (first, Base. Reverse),
3840 map (
3941 Iterators. filter (((a, b),) -> config. outer || ! isdisjoint (a. head, b. head), combinations (path. args, 2 )),
4042 ) do (a, b)
4143 # TODO don't consider outer products
42- candidate = sum ([a, b], skip = path. head ∪ hyperinds (path))
44+ candidate = sum ([a, b], skip = hashyperinds ? path. head ∪ hyperinds (path) : path . head )
4345 weight = metric (candidate)
4446 (weight, candidate)
4547 end ,
@@ -58,7 +60,7 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L}
5860 # update candidate queue
5961 for other in Iterators. filter (other -> config. outer || ! isdisjoint (winner. head, other. head), path. args)
6062 # TODO don't consider outer products
61- candidate = sum ([winner, other], skip = path. head ∪ hyperinds (path))
63+ candidate = sum ([winner, other], skip = hashyperinds ? path. head ∪ hyperinds (path) : path . head )
6264 weight = metric (candidate)
6365 push! (queue, (weight, candidate))
6466 end
0 commit comments