diff --git a/Project.toml b/Project.toml index e475c1a..1265375 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "5.3.0" +version = "5.3.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/indexnotation/instantiators.jl b/src/indexnotation/instantiators.jl index 9bd5b44..c05b351 100644 --- a/src/indexnotation/instantiators.jl +++ b/src/indexnotation/instantiators.jl @@ -122,18 +122,15 @@ function instantiate_generaltensor( β = βsym end if alloc ∈ (NewTensor, TemporaryTensor) - TC = gensym("T_" * string(dst)) + TCsym = gensym("T_" * string(dst)) istemporary = Val(alloc === TemporaryTensor) - if scaltype === nothing - TCval = α === One() ? instantiate_scalartype(src) : - instantiate_scalartype(Expr(:call, :*, α, src)) - else - TCval = scaltype - end - push!(out.args, Expr(:(=), TC, TCval)) + TCval = @something( + scaltype, instantiate_scalartype(α === One() ? src : Expr(:call, :*, α, src)) + ) + push!(out.args, Expr(:(=), TCsym, TCval)) push!( out.args, - Expr(:(=), dst, :(tensoralloc_add($TC, $src, $p, $conj, $istemporary))) + Expr(:(=), dst, :(tensoralloc_add($TCsym, $src, $p, $conj, $istemporary))) ) end @@ -167,9 +164,9 @@ function instantiate_linearcombination( ) out = Expr(:block) if alloc ∈ (NewTensor, TemporaryTensor) - if scaltype === nothing - scaltype = instantiate_scalartype(ex) - end + scaltype = @something( + scaltype, instantiate_scalartype(α === One() ? ex : Expr(:call, :*, α, ex)) + ) push!( out.args, instantiate(dst, β, ex.args[2], α, leftind, rightind, alloc, scaltype) @@ -275,18 +272,15 @@ function instantiate_contraction( end if alloc ∈ (NewTensor, TemporaryTensor) TCsym = gensym("T_" * string(dst)) - if scaltype === nothing - Atype = instantiate_scalartype(A) - Btype = instantiate_scalartype(B) - TCval = Expr(:call, :promote_contract, Atype, Btype) - if α !== One() - TCval = Expr( - :call, :(Base.promote_op), :*, instantiate_scalartype(α), TCval - ) + TCval = @something( + scaltype, + begin + TA = instantiate_scalartype(A) + TB = instantiate_scalartype(B) + TAB = :(promote_contract($TA, $TB)) + α === One() ? TAB : :(Base.promote_op(*, $(instantiate_scalartype(α)), $TAB)) end - else - TCval = scaltype - end + ) istemporary = Val(alloc === TemporaryTensor) initC = Expr( :block, Expr(:(=), TCsym, TCval), diff --git a/test/tensor.jl b/test/tensor.jl index 9ba9400..7e97c36 100644 --- a/test/tensor.jl +++ b/test/tensor.jl @@ -581,4 +581,15 @@ end @test isblascontractable(pA, p) @test isblascontractable(conj(pA), p) end + + @testset "Issue 220" begin + A = rand(2, 2) + B = rand(2, 2) + C = rand(2, 2) + D = rand(2, 2) + c = 1im + @tensor E[a; c] := c * (A[a b] * B[b c] + C[a b] * D[b c]) + @test scalartype(E) == ComplexF64 + @test E ≈ c * (A * B + C * D) + end end