@@ -4,29 +4,23 @@ function optimaltree(network, optdata::Dict; verbose::Bool=false)
44 numindices = length (allindices)
55 costtype = valtype (optdata)
66 allcosts = costtype[get (optdata, i, one (costtype)) for i in allindices]
7- maxcost = addcost (mulcost (reduce (mulcost, allcosts; init= one (costtype)),
8- maximum (allcosts)), zero (costtype)) # add zero for type stability: Power -> Poly
7+
98 tensorcosts = Vector {costtype} (undef, numtensors)
109 for k in 1 : numtensors
1110 tensorcosts[k] = mapreduce (i -> get (optdata, i, one (costtype)), mulcost, network[k];
1211 init= one (costtype))
1312 end
14- initialcost = min (maxcost,
15- addcost (mulcost (maximum (tensorcosts), minimum (tensorcosts)),
16- zero (costtype))) # just some arbitrary guess
13+ initialcost = addcost (mulcost (maximum (tensorcosts), minimum (tensorcosts)),
14+ zero (costtype)) # just some arbitrary guess
1715
1816 if numindices <= 32
19- return _optimaltree (UInt32, network, allindices, allcosts, initialcost, maxcost;
20- verbose= verbose)
17+ return _optimaltree (UInt32, network, allindices, allcosts, initialcost; verbose)
2118 elseif numindices <= 64
22- return _optimaltree (UInt64, network, allindices, allcosts, initialcost, maxcost;
23- verbose= verbose)
19+ return _optimaltree (UInt64, network, allindices, allcosts, initialcost; verbose)
2420 elseif numindices <= 128 && ! (@static Int == Int32 && Sys. iswindows () ? true : false )
25- return _optimaltree (UInt128, network, allindices, allcosts, initialcost, maxcost;
26- verbose= verbose)
21+ return _optimaltree (UInt128, network, allindices, allcosts, initialcost; verbose)
2722 else
28- return _optimaltree (BitVector, network, allindices, allcosts, initialcost, maxcost;
29- verbose= verbose)
23+ return _optimaltree (BitVector, network, allindices, allcosts, initialcost; verbose)
3024 end
3125end
3226
@@ -95,8 +89,23 @@ function computecost(allcosts, ind1::BitSet, ind2::BitSet)
9589 return cost
9690end
9791
98- function _optimaltree (:: Type{T} , network, allindices, allcosts:: Vector{S} , initialcost:: C ,
99- maxcost:: C ; verbose:: Bool = false ) where {T,S,C}
92+ function computemaxcost (allcosts, indexsets)
93+ if length (indexsets) ≤ 1
94+ maxcost = one (eltype (allcosts))
95+ else
96+ maxcost = zero (eltype (allcosts))
97+ s1 = indexsets[1 ]
98+ for n in 2 : length (indexsets)
99+ s2 = indexsets[n]
100+ maxcost = addcost (maxcost, computecost (allcosts, s1, s2))
101+ s1 = _setdiff (_union (s1, s2), _intersect (s1, s2))
102+ end
103+ end
104+ return addcost (maxcost, zero (maxcost)) # add zero for type stability: Power -> Poly
105+ end
106+
107+ function _optimaltree (:: Type{T} , network, allindices, allcosts:: Vector{S} , initialcost:: C ;
108+ verbose:: Bool = false ) where {T,S,C}
100109 numindices = length (allindices)
101110 numtensors = length (network)
102111 indexsets = Array {T} (undef, numtensors)
@@ -156,7 +165,8 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi
156165 end
157166
158167 # run over currentcost
159- currentcost = initialcost
168+ maxcost = computemaxcost (allcosts, @view (indexsets[component]))
169+ currentcost = min (initialcost, maxcost)
160170 previouscost = zero (initialcost)
161171 while currentcost <= maxcost
162172 nextcost = maxcost
0 commit comments