Skip to content

Commit 413b2aa

Browse files
committed
handle cases with more tensors
1 parent 978de1e commit 413b2aa

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/indexnotation/optimaltree.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ function optimaltree(network, optdata::Dict; verbose::Bool=false)
44
numindices = length(allindices)
55
costtype = valtype(optdata)
66
allcosts = costtype[get(optdata, i, one(costtype)) for i in allindices]
7-
maxcost = addcost(mulcost(reduce(mulcost, allcosts; init=one(costtype)),
8-
maximum(allcosts)), one(costtype))
9-
# add one for type stability: Power -> Poly
10-
# and for dealing with cases where all sizes are 1 -> maxcost = 2
7+
maxcost = max(mulcost(reduce(mulcost, allcosts; init=one(costtype)),
8+
maximum(allcosts)), costtype(length(network) - 1))
9+
# add zero for type stability: Power -> Poly
10+
maxcost = addcost(maxcost, zero(costtype))
11+
1112
tensorcosts = Vector{costtype}(undef, numtensors)
1213
for k in 1:numtensors
1314
tensorcosts[k] = mapreduce(i -> get(optdata, i, one(costtype)), mulcost, network[k];

test/tensoropt.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,13 @@ end
5858

5959
@testset "Issue #206" begin
6060
# https://github.com/Jutho/TensorOperations.jl/issues/206
61-
network = [[:a, :b], [:a], [:b]]
62-
opt_data = Dict(:a => 1, :b => 1)
61+
network = [[Symbol(1), Symbol(2)], [Symbol(1)], [Symbol(2)]]
62+
opt_data = Dict(Symbol(1) => 1, Symbol(2) => 1)
6363
tree, cost = TensorOperations.optimaltree(network, opt_data)
6464
@test cost == 2
65+
66+
network = network = [[:a, :b, :c], [:a], [:b], [:c]]
67+
opt_data = Dict(:a => 1, :b => 1, :c => 1)
68+
tree, cost = TensorOperations.optimaltree(network, opt_data)
69+
@test cost == 3
6570
end

0 commit comments

Comments
 (0)