Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Test = "1"
TestExtras = "0.2,0.3"
TupleTools = "1.1"
VectorInterface = "0.4, 0.5"
Zygote = "0.7"
julia = "1.10"

[extras]
Expand All @@ -53,6 +54,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"]
test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
17 changes: 17 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
return copy(t), copy_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.copy_oftype), t::AbstractTensorMap,
T::Type{<:Number})
project = ProjectTo(t)
copy_oftype_pullback(Δt) = NoTangent(), project(unthunk(Δt)), NoTangent()
return TensorKit.copy_oftype(t, T), copy_oftype_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.permutedcopy_oftype), t::AbstractTensorMap,
T::Type{<:Number}, p::Index2Tuple)
project = ProjectTo(t)
function permutedcopy_oftype_pullback(Δt)
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), t)
return project(TensorKit.permutedcopy_oftype(unthunk(Δt), scalartype(t), invp))
end
return TensorKit.permutedcopy_oftype(t, T, p), permutedcopy_oftype_pullback
end

function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
t::AbstractTensorMap)
A = convert(T, t)
Expand Down
3 changes: 2 additions & 1 deletion ext/TensorKitChainRulesCoreExt/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺
end

function tsvd!_pullback((ΔU, ΔΣ, ΔV⁺, Δϵ))
function tsvd!_pullback(ΔUSVϵ)
ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ)
Δt = similar(t)
for (c, b) in blocks(Δt)
Uc, Σc, V⁺c = block(U, c), block(Σ, c), block(V⁺, c)
Expand Down
18 changes: 9 additions & 9 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,39 +269,39 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
end

function tsvd(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return tsvd!(tcopy; kwargs...)
end
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return leftorth!(tcopy; alg=alg, kwargs...)
end
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return rightorth!(tcopy; alg=alg, kwargs...)
end
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return leftnull!(tcopy; alg=alg, kwargs...)
end
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return rightnull!(tcopy; alg=alg, kwargs...)
end
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return eigen!(tcopy; kwargs...)
end
function eig(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return eig!(tcopy; kwargs...)
end
function eigh(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return eigh!(tcopy; kwargs...)
end
function LinearAlgebra.isposdef(t::AbstractTensorMap)
tcopy = copy!(similar(t, float(scalartype(t))), t)
tcopy = copy_oftype(t, float(scalartype(t)))
return isposdef!(tcopy)
end

Expand Down
2 changes: 2 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),

test_rrule(copy, T1)
test_rrule(copy, T2)
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4)))

test_rrule(convert, Array, T1)
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
Expand Down
28 changes: 28 additions & 0 deletions test/bugfixes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,32 @@
@test storagetype(t5) == Vector{Float64}
tensorfree!(t2)
end

@testset "Issue #201" begin
function f(A::AbstractTensorMap)
U, S, V, = tsvd(A)
return tr(S)
end
function f(A::AbstractMatrix)
S = LinearAlgebra.svdvals(A)
return sum(S)
end
A₀ = randn(Z2Space(4, 4) ← Z2Space(4, 4))
grad1, = Zygote.gradient(f, A₀)
grad2, = Zygote.gradient(f, convert(Array, A₀))
@test convert(Array, grad1) ≈ grad2

function g(A::AbstractTensorMap)
U, S, V, = tsvd(A)
return tr(U * V)
end
function g(A::AbstractMatrix)
U, S, V, = LinearAlgebra.svd(A)
return tr(U * V')
end
B₀ = randn(ComplexSpace(4) ← ComplexSpace(4))
grad3, = Zygote.gradient(g, B₀)
grad4, = Zygote.gradient(g, convert(Array, B₀))
@test convert(Array, grad3) ≈ grad4
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Base.Iterators: take, product
# using SUNRepresentations: SUNIrrep
# const SU3Irrep = SUNIrrep{3}
using LinearAlgebra: LinearAlgebra
using Zygote: Zygote

const TK = TensorKit

Expand Down