Skip to content

Commit 54cbd9b

Browse files
authored
optimaltree hangs on edge case with all costs equal to 1 (#207)
* update `maxcost` * Add test * handle cases with more tensors * Refactor and improve * Fix typo
1 parent 0fad751 commit 54cbd9b

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

src/indexnotation/optimaltree.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,23 @@ function optimaltree(network, optdata::Dict; verbose::Bool=false)
44
numindices = length(allindices)
55
costtype = valtype(optdata)
66
allcosts = costtype[get(optdata, i, one(costtype)) for i in allindices]
7-
maxcost = addcost(mulcost(reduce(mulcost, allcosts; init=one(costtype)),
8-
maximum(allcosts)), zero(costtype)) # add zero for type stability: Power -> Poly
7+
98
tensorcosts = Vector{costtype}(undef, numtensors)
109
for k in 1:numtensors
1110
tensorcosts[k] = mapreduce(i -> get(optdata, i, one(costtype)), mulcost, network[k];
1211
init=one(costtype))
1312
end
14-
initialcost = min(maxcost,
15-
addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)),
16-
zero(costtype))) # just some arbitrary guess
13+
initialcost = addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)),
14+
zero(costtype)) # just some arbitrary guess
1715

1816
if numindices <= 32
19-
return _optimaltree(UInt32, network, allindices, allcosts, initialcost, maxcost;
20-
verbose=verbose)
17+
return _optimaltree(UInt32, network, allindices, allcosts, initialcost; verbose)
2118
elseif numindices <= 64
22-
return _optimaltree(UInt64, network, allindices, allcosts, initialcost, maxcost;
23-
verbose=verbose)
19+
return _optimaltree(UInt64, network, allindices, allcosts, initialcost; verbose)
2420
elseif numindices <= 128 && !(@static Int == Int32 && Sys.iswindows() ? true : false)
25-
return _optimaltree(UInt128, network, allindices, allcosts, initialcost, maxcost;
26-
verbose=verbose)
21+
return _optimaltree(UInt128, network, allindices, allcosts, initialcost; verbose)
2722
else
28-
return _optimaltree(BitVector, network, allindices, allcosts, initialcost, maxcost;
29-
verbose=verbose)
23+
return _optimaltree(BitVector, network, allindices, allcosts, initialcost; verbose)
3024
end
3125
end
3226

@@ -95,8 +89,23 @@ function computecost(allcosts, ind1::BitSet, ind2::BitSet)
9589
return cost
9690
end
9791

98-
function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C,
99-
maxcost::C; verbose::Bool=false) where {T,S,C}
92+
function computemaxcost(allcosts, indexsets)
93+
if length(indexsets) 1
94+
maxcost = one(eltype(allcosts))
95+
else
96+
maxcost = zero(eltype(allcosts))
97+
s1 = indexsets[1]
98+
for n in 2:length(indexsets)
99+
s2 = indexsets[n]
100+
maxcost = addcost(maxcost, computecost(allcosts, s1, s2))
101+
s1 = _setdiff(_union(s1, s2), _intersect(s1, s2))
102+
end
103+
end
104+
return addcost(maxcost, zero(maxcost)) # add zero for type stability: Power -> Poly
105+
end
106+
107+
function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C;
108+
verbose::Bool=false) where {T,S,C}
100109
numindices = length(allindices)
101110
numtensors = length(network)
102111
indexsets = Array{T}(undef, numtensors)
@@ -156,7 +165,8 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi
156165
end
157166

158167
# run over currentcost
159-
currentcost = initialcost
168+
maxcost = computemaxcost(allcosts, @view(indexsets[component]))
169+
currentcost = min(initialcost, maxcost)
160170
previouscost = zero(initialcost)
161171
while currentcost <= maxcost
162172
nextcost = maxcost

test/tensoropt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,16 @@
5555
1,
5656
0, 0, 0, 3, 0, 3, 2, 0, 2, 4])
5757
end
58+
59+
@testset "Issue #206" begin
60+
# https://github.com/Jutho/TensorOperations.jl/issues/206
61+
network = [[Symbol(1), Symbol(2)], [Symbol(1)], [Symbol(2)]]
62+
opt_data = Dict(Symbol(1) => 1, Symbol(2) => 1)
63+
tree, cost = TensorOperations.optimaltree(network, opt_data)
64+
@test cost == 2
65+
66+
network = network = [[:a, :b, :c], [:a], [:b], [:c]]
67+
opt_data = Dict(:a => 1, :b => 1, :c => 1)
68+
tree, cost = TensorOperations.optimaltree(network, opt_data)
69+
@test cost == 3
70+
end

0 commit comments

Comments
 (0)