Skip to content

Commit a61a67b

Browse files
authored
Fix VectorInterface.scalartype for arrays of ITensors (#1614)
1 parent 37d1570 commit a61a67b

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensors"
22
uuid = "9136182c-28ba-11e9-034c-db9fb085ebd5"
33
authors = ["Matthew Fishman <[email protected]>", "Miles Stoudenmire <[email protected]>"]
4-
version = "0.7.12"
4+
version = "0.7.13"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ITensorsVectorInterfaceExt/ITensorsVectorInterfaceExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ function VectorInterface.scalartype(a::ITensor)
5858
return ITensors.scalartype(a)
5959
end
6060

61+
# Circumvent issue that `VectorInterface.jl` computes
62+
# the scalartype in the type domain, which isn't known
63+
# for ITensors.
64+
function VectorInterface.scalartype(a::AbstractArray{ITensor})
65+
# Like the implementation of `LinearAlgebra.promote_leaf_eltypes`:
66+
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/e7da19f2764ba36bd0a9eb8ec67dddce19d87114/src/generic.jl#L1933
67+
return mapreduce(VectorInterface.scalartype, promote_type, a; init=Bool)
68+
end
69+
6170
function VectorInterface.scale(a::ITensor, α::Number)
6271
return a * α
6372
end

test/ext/ITensorsVectorInterfaceExt/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@eval module $(gensym())
2-
using ITensors: Index, dag, inds, random_itensor
2+
using ITensors: ITensor, Index, dag, inds, random_itensor
33
using Test: @test, @testset
44
using VectorInterface:
55
add,
@@ -68,6 +68,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
6868
# scalartype
6969
@test scalartype(a) === elt
7070
@test scalartype(b) === elt
71+
@test scalartype([a, b]) === elt
72+
@test scalartype([a, random_itensor(Float32, i, j)]) === elt
73+
@test scalartype(ITensor[]) === Bool
7174

7275
# scale
7376
@test scale(a, α) α * a

0 commit comments

Comments
 (0)