diff --git a/src/indexnotation/optimaltree.jl b/src/indexnotation/optimaltree.jl index 5e1fd101..f8859bb9 100644 --- a/src/indexnotation/optimaltree.jl +++ b/src/indexnotation/optimaltree.jl @@ -4,29 +4,23 @@ function optimaltree(network, optdata::Dict; verbose::Bool=false) numindices = length(allindices) costtype = valtype(optdata) allcosts = costtype[get(optdata, i, one(costtype)) for i in allindices] - maxcost = addcost(mulcost(reduce(mulcost, allcosts; init=one(costtype)), - maximum(allcosts)), zero(costtype)) # add zero for type stability: Power -> Poly + tensorcosts = Vector{costtype}(undef, numtensors) for k in 1:numtensors tensorcosts[k] = mapreduce(i -> get(optdata, i, one(costtype)), mulcost, network[k]; init=one(costtype)) end - initialcost = min(maxcost, - addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)), - zero(costtype))) # just some arbitrary guess + initialcost = addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)), + zero(costtype)) # just some arbitrary guess if numindices <= 32 - return _optimaltree(UInt32, network, allindices, allcosts, initialcost, maxcost; - verbose=verbose) + return _optimaltree(UInt32, network, allindices, allcosts, initialcost; verbose) elseif numindices <= 64 - return _optimaltree(UInt64, network, allindices, allcosts, initialcost, maxcost; - verbose=verbose) + return _optimaltree(UInt64, network, allindices, allcosts, initialcost; verbose) elseif numindices <= 128 && !(@static Int == Int32 && Sys.iswindows() ? true : false) - return _optimaltree(UInt128, network, allindices, allcosts, initialcost, maxcost; - verbose=verbose) + return _optimaltree(UInt128, network, allindices, allcosts, initialcost; verbose) else - return _optimaltree(BitVector, network, allindices, allcosts, initialcost, maxcost; - verbose=verbose) + return _optimaltree(BitVector, network, allindices, allcosts, initialcost; verbose) end end @@ -95,8 +89,23 @@ function computecost(allcosts, ind1::BitSet, ind2::BitSet) return cost end -function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C, - maxcost::C; verbose::Bool=false) where {T,S,C} +function computemaxcost(allcosts, indexsets) + if length(indexsets) ≤ 1 + maxcost = one(eltype(allcosts)) + else + maxcost = zero(eltype(allcosts)) + s1 = indexsets[1] + for n in 2:length(indexsets) + s2 = indexsets[n] + maxcost = addcost(maxcost, computecost(allcosts, s1, s2)) + s1 = _setdiff(_union(s1, s2), _intersect(s1, s2)) + end + end + return addcost(maxcost, zero(maxcost)) # add zero for type stability: Power -> Poly +end + +function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C; + verbose::Bool=false) where {T,S,C} numindices = length(allindices) numtensors = length(network) indexsets = Array{T}(undef, numtensors) @@ -156,7 +165,8 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi end # run over currentcost - currentcost = initialcost + maxcost = computemaxcost(allcosts, @view(indexsets[component])) + currentcost = min(initialcost, maxcost) previouscost = zero(initialcost) while currentcost <= maxcost nextcost = maxcost diff --git a/test/tensoropt.jl b/test/tensoropt.jl index 1e221e64..093393ea 100644 --- a/test/tensoropt.jl +++ b/test/tensoropt.jl @@ -55,3 +55,16 @@ 1, 0, 0, 0, 3, 0, 3, 2, 0, 2, 4]) end + +@testset "Issue #206" begin + # https://github.com/Jutho/TensorOperations.jl/issues/206 + network = [[Symbol(1), Symbol(2)], [Symbol(1)], [Symbol(2)]] + opt_data = Dict(Symbol(1) => 1, Symbol(2) => 1) + tree, cost = TensorOperations.optimaltree(network, opt_data) + @test cost == 2 + + network = network = [[:a, :b, :c], [:a], [:b], [:c]] + opt_data = Dict(:a => 1, :b => 1, :c => 1) + tree, cost = TensorOperations.optimaltree(network, opt_data) + @test cost == 3 +end