From 537ba159876bd9378ebb683023fd7dcf6d0f9257 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 22 Sep 2025 12:02:54 +0200 Subject: [PATCH 1/4] Include alpha for determining scalartype of linear combination --- src/indexnotation/instantiators.jl | 6 +++--- test/tensor.jl | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/indexnotation/instantiators.jl b/src/indexnotation/instantiators.jl index 9bd5b44..1784734 100644 --- a/src/indexnotation/instantiators.jl +++ b/src/indexnotation/instantiators.jl @@ -167,9 +167,9 @@ function instantiate_linearcombination( ) out = Expr(:block) if alloc ∈ (NewTensor, TemporaryTensor) - if scaltype === nothing - scaltype = instantiate_scalartype(ex) - end + scaltype = @something( + scaltype, instantiate_scalartype(α === One() ? ex : Expr(:call, :*, α, ex)) + ) push!( out.args, instantiate(dst, β, ex.args[2], α, leftind, rightind, alloc, scaltype) diff --git a/test/tensor.jl b/test/tensor.jl index 9ba9400..7e97c36 100644 --- a/test/tensor.jl +++ b/test/tensor.jl @@ -581,4 +581,15 @@ end @test isblascontractable(pA, p) @test isblascontractable(conj(pA), p) end + + @testset "Issue 220" begin + A = rand(2, 2) + B = rand(2, 2) + C = rand(2, 2) + D = rand(2, 2) + c = 1im + @tensor E[a; c] := c * (A[a b] * B[b c] + C[a b] * D[b c]) + @test scalartype(E) == ComplexF64 + @test E ≈ c * (A * B + C * D) + end end From 48649716d38fa5bdc253cb6d4a2b82119ff6e5df Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 07:38:44 -0400 Subject: [PATCH 2/4] Bump v5.3.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e475c1a..1265375 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "5.3.0" +version = "5.3.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" From 039d72d60edc22a329ad08febf7794fced1e56da Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 18:44:49 -0400 Subject: [PATCH 3/4] some unification of scaltype determination --- src/indexnotation/instantiators.jl | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/indexnotation/instantiators.jl b/src/indexnotation/instantiators.jl index 1784734..2920447 100644 --- a/src/indexnotation/instantiators.jl +++ b/src/indexnotation/instantiators.jl @@ -124,12 +124,9 @@ function instantiate_generaltensor( if alloc ∈ (NewTensor, TemporaryTensor) TC = gensym("T_" * string(dst)) istemporary = Val(alloc === TemporaryTensor) - if scaltype === nothing - TCval = α === One() ? instantiate_scalartype(src) : - instantiate_scalartype(Expr(:call, :*, α, src)) - else - TCval = scaltype - end + TCval = @something( + scaltype, instantiate_scalartype(α === One() ? src : Expr(:call, :*, α, src)) + ) push!(out.args, Expr(:(=), TC, TCval)) push!( out.args, @@ -275,18 +272,15 @@ function instantiate_contraction( end if alloc ∈ (NewTensor, TemporaryTensor) TCsym = gensym("T_" * string(dst)) - if scaltype === nothing - Atype = instantiate_scalartype(A) - Btype = instantiate_scalartype(B) - TCval = Expr(:call, :promote_contract, Atype, Btype) - if α !== One() - TCval = Expr( - :call, :(Base.promote_op), :*, instantiate_scalartype(α), TCval - ) + TCval = @something( + scaltype, + begin + TA = instantiate_scalartype(A) + TB = instantiate_scalartype(B) + TAB = :(promote_contract($TA, $TB)) + α === One() ? TAB : :(Base.promote_op(*, $(instantiate_scalartype(α)), $TAB)) end - else - TCval = scaltype - end + ) istemporary = Val(alloc === TemporaryTensor) initC = Expr( :block, Expr(:(=), TCsym, TCval), From 573529f7f47497b399bd71aa4e45b945d0c58594 Mon Sep 17 00:00:00 2001 From: Jutho Date: Wed, 24 Sep 2025 23:26:05 +0200 Subject: [PATCH 4/4] Some more consistency changes --- src/indexnotation/instantiators.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/indexnotation/instantiators.jl b/src/indexnotation/instantiators.jl index 2920447..c05b351 100644 --- a/src/indexnotation/instantiators.jl +++ b/src/indexnotation/instantiators.jl @@ -122,15 +122,15 @@ function instantiate_generaltensor( β = βsym end if alloc ∈ (NewTensor, TemporaryTensor) - TC = gensym("T_" * string(dst)) + TCsym = gensym("T_" * string(dst)) istemporary = Val(alloc === TemporaryTensor) TCval = @something( scaltype, instantiate_scalartype(α === One() ? src : Expr(:call, :*, α, src)) ) - push!(out.args, Expr(:(=), TC, TCval)) + push!(out.args, Expr(:(=), TCsym, TCval)) push!( out.args, - Expr(:(=), dst, :(tensoralloc_add($TC, $src, $p, $conj, $istemporary))) + Expr(:(=), dst, :(tensoralloc_add($TCsym, $src, $p, $conj, $istemporary))) ) end