Skip to content

Commit a847f15

Browse files
committed
Refactor and improve
1 parent 413b2aa commit a847f15

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

src/indexnotation/optimaltree.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,23 @@ function optimaltree(network, optdata::Dict; verbose::Bool=false)
44
numindices = length(allindices)
55
costtype = valtype(optdata)
66
allcosts = costtype[get(optdata, i, one(costtype)) for i in allindices]
7-
maxcost = max(mulcost(reduce(mulcost, allcosts; init=one(costtype)),
8-
maximum(allcosts)), costtype(length(network) - 1))
9-
# add zero for type stability: Power -> Poly
10-
maxcost = addcost(maxcost, zero(costtype))
117

128
tensorcosts = Vector{costtype}(undef, numtensors)
139
for k in 1:numtensors
1410
tensorcosts[k] = mapreduce(i -> get(optdata, i, one(costtype)), mulcost, network[k];
1511
init=one(costtype))
1612
end
17-
initialcost = min(maxcost,
18-
addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)),
19-
zero(costtype))) # just some arbitrary guess
13+
initialcost = addcost(mulcost(maximum(tensorcosts), minimum(tensorcosts)),
14+
zero(costtype)) # just some arbitrary guess
2015

2116
if numindices <= 32
22-
return _optimaltree(UInt32, network, allindices, allcosts, initialcost, maxcost;
23-
verbose=verbose)
17+
return _optimaltree(UInt32, network, allindices, allcosts, initialcost; verbose)
2418
elseif numindices <= 64
25-
return _optimaltree(UInt64, network, allindices, allcosts, initialcost, maxcost;
26-
verbose=verbose)
19+
return _optimaltree(UInt64, network, allindices, allcosts, initialcost; verbose)
2720
elseif numindices <= 128 && !(@static Int == Int32 && Sys.iswindows() ? true : false)
28-
return _optimaltree(UInt128, network, allindices, allcosts, initialcost, maxcost;
29-
verbose=verbose)
21+
return _optimaltree(UInt128, network, allindices, allcosts, initialcost; verbose)
3022
else
31-
return _optimaltree(BitVector, network, allindices, allcosts, initialcost, maxcost;
32-
verbose=verbose)
23+
return _optimaltree(BitVector, network, allindices, allcosts, initialcost; verbose)
3324
end
3425
end
3526

@@ -98,8 +89,23 @@ function computecost(allcosts, ind1::BitSet, ind2::BitSet)
9889
return cost
9990
end
10091

101-
function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C,
102-
maxcost::C; verbose::Bool=false) where {T,S,C}
92+
function computemaxcost(allcosts, indexsets)
93+
if length(indexsets) == 1
94+
maxcost = one(etltype(allcosts))
95+
else
96+
maxcost = zero(eltype(allcosts))
97+
s1 = indexsets[1]
98+
for n in 2:length(indexsets)
99+
s2 = indexsets[n]
100+
maxcost = addcost(maxcost, computecost(allcosts, s1, s2))
101+
s1 = _setdiff(_union(s1, s2), _intersect(s1, s2))
102+
end
103+
end
104+
return addcost(maxcost, zero(maxcost)) # add zero for type stability: Power -> Poly
105+
end
106+
107+
function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C;
108+
verbose::Bool=false) where {T,S,C}
103109
numindices = length(allindices)
104110
numtensors = length(network)
105111
indexsets = Array{T}(undef, numtensors)
@@ -159,7 +165,8 @@ function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initi
159165
end
160166

161167
# run over currentcost
162-
currentcost = initialcost
168+
maxcost = computemaxcost(allcosts, @view(indexsets[component]))
169+
currentcost = min(initialcost, maxcost)
163170
previouscost = zero(initialcost)
164171
while currentcost <= maxcost
165172
nextcost = maxcost

0 commit comments

Comments
 (0)