diff --git a/dev/KrylovKit b/dev/KrylovKit
new file mode 160000
index 000000000..8bccac88a
--- /dev/null
+++ b/dev/KrylovKit
@@ -0,0 +1 @@
+Subproject commit 8bccac88a9474b47ce49bac72ae19b3806ce129f
diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl
index f48ed4393..4b9dcd6b1 100644
--- a/ext/TensorKitChainRulesCoreExt/constructors.jl
+++ b/ext/TensorKitChainRulesCoreExt/constructors.jl
@@ -4,7 +4,7 @@
 @non_differentiable TensorKit.isometry(args...)
 @non_differentiable TensorKit.unitary(args...)
 
-function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
+function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...)
     function TensorMap_pullback(Δt)
         ∂d = convert(Array, unthunk(Δt))
         return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
@@ -12,6 +12,83 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg
     return TensorMap(d, args...; kwargs...), TensorMap_pullback
 end
 
+# these are not the conversion to/from array, but actually take in data parameters
+# -- as a result, requires quantum dimensions to keep inner product the same:
+#   ⟨Δdata, ∂data⟩ = ⟨Δtensor, ∂tensor⟩ = ∑_c d_c ⟨Δtensor_c, ∂tensor_c⟩
+#   ⟹ Δdata = d_c Δtensor_c
+function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector,
+                              V::TensorMapSpace) where {T}
+    t = TensorMap{T}(data, V)
+    P = ProjectTo(data)
+    function TensorMap_pullback(Δt_)
+        Δt = copy(unthunk(Δt_))
+        for (c, b) in blocks(Δt)
+            scale!(b, dim(c))
+        end
+        ∂data = P(Δt.data)
+        return NoTangent(), ∂data, NoTangent()
+    end
+    return t, TensorMap_pullback
+end
+
+# rrule for building a `DiagonalTensorMap` from its data vector: the cotangent blocks are
+# scaled by `dim(c)` as above, and each extra positional argument gets a `NoTangent()`.
+function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, args...;
+                              kwargs...)
+    D = DiagonalTensorMap(data, args...; kwargs...)
+    P = ProjectTo(data)
+    function DiagonalTensorMap_pullback(Δt_)
+        # unclear if we're allowed to modify/take ownership of the input
+        Δt = copy(unthunk(Δt_))
+        for (c, b) in blocks(Δt)
+            scale!(b, dim(c))
+        end
+        ∂data = P(Δt.data)
+        return NoTangent(), ∂data, ntuple(_ -> NoTangent(), length(args))...
+    end
+    return D, DiagonalTensorMap_pullback
+end
+
+# rrule for `getproperty` on a `TensorMap`: the `:data` pullback rewraps the cotangent in a
+# tensor and scales each block by `inv(dim(c))`; `:space` has a constant pullback.
+function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol)
+    if prop === :data
+        function getdata_pullback(Δdata)
+            # unclear if we're allowed to modify/take ownership of the input
+            t′ = typeof(t)(copy(unthunk(Δdata)), t.space)
+            for (c, b) in blocks(t′)
+                scale!(b, inv(dim(c)))
+            end
+            return NoTangent(), t′, NoTangent()
+        end
+        return t.data, getdata_pullback
+    elseif prop === :space
+        return t.space, Returns((NoTangent(), ZeroTangent(), NoTangent()))
+    else
+        throw(ArgumentError("unknown property $prop"))
+    end
+end
+
+# same for `DiagonalTensorMap`, which is rebuilt from `t.domain` instead of `t.space`.
+function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap,
+                              prop::Symbol)
+    if prop === :data
+        function getdata_pullback(Δdata)
+            # unclear if we're allowed to modify/take ownership of the input
+            t′ = typeof(t)(copy(unthunk(Δdata)), t.domain)
+            for (c, b) in blocks(t′)
+                scale!(b, inv(dim(c)))
+            end
+            return NoTangent(), t′, NoTangent()
+        end
+        return t.data, getdata_pullback
+    elseif prop === :domain
+        return t.domain, Returns((NoTangent(), ZeroTangent(), NoTangent()))
+    else
+        throw(ArgumentError("unknown property $prop"))
+    end
+end
+
 function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
     copy_pullback(Δt) = NoTangent(), Δt
     return copy(t), copy_pullback
diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl
index 922900b6f..67dc5a980 100644
--- a/ext/TensorKitChainRulesCoreExt/linalg.jl
+++ b/ext/TensorKitChainRulesCoreExt/linalg.jl
@@ -113,7 +113,8 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
     return a_imag, imag_pullback
 end
 
-function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap)
+function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(exp),
+                              A::AbstractTensorMap)
     domain(A) == codomain(A) ||
         error("Exponential of a tensor only exist when domain == codomain.")
     P_A = ProjectTo(A)
@@ -133,3 +134,21 @@ function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorM
     end
     return C, exp_pullback
 end
+
+# define rrules for matrix functions for DiagonalTensorMap, since they access data directly.
+for f in
+    (:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt,
+     :log, :asin, :acos, :acosh, :atanh, :acoth)
+    f_pullback = Symbol(f, :_pullback)
+    @eval function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($f),
+                                        t::DiagonalTensorMap)
+        P = ProjectTo(t) # unsure if this is necessary, should already be in pullback
+        d, pullback = rrule_via_ad(cfg, broadcast, $f, t.data)
+        function $f_pullback(Δd_)
+            Δd = P(unthunk(Δd_))
+            _, _, ∂data = pullback(Δd.data)
+            return NoTangent(), DiagonalTensorMap(∂data, t.domain)
+        end
+        return DiagonalTensorMap(d, t.domain), $f_pullback
+    end
+end
diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl
index f270e346a..2a9f11d57 100644
--- a/ext/TensorKitChainRulesCoreExt/utility.jl
+++ b/ext/TensorKitChainRulesCoreExt/utility.jl
@@ -32,3 +32,17 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N
     end
     return y
 end
+
+# project `x` onto a `DiagonalTensorMap{T,S,A}`: a tensor of exactly that type is returned
+# as is; otherwise `x` must have space `V ← V` and its blocks are projected one by one.
+function (::ProjectTo{DiagonalTensorMap{T,S,A}})(x::AbstractTensorMap) where {T,S,A}
+    x isa DiagonalTensorMap{T,S,A} && return x
+    V = space(x, 1)
+    space(x) == (V ← V) || throw(SpaceMismatch())
+    y = DiagonalTensorMap{T,S,A}(undef, V)
+    for (c, b) in blocks(y)
+        p = ProjectTo(b)
+        b .= p(block(x, c))
+    end
+    return y
+end
diff --git a/ext/TensorKitFiniteDifferencesExt.jl b/ext/TensorKitFiniteDifferencesExt.jl
index f62642560..63c9711e3 100644
--- a/ext/TensorKitFiniteDifferencesExt.jl
+++ b/ext/TensorKitFiniteDifferencesExt.jl
@@ -23,6 +23,16 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
 end
 FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))
 
+# vectorize through `TensorMap(t)`; the inverse map converts the rebuilt tensor back to a
+# `DiagonalTensorMap`.
+function FiniteDifferences.to_vec(t::DiagonalTensorMap)
+    x_vec, back = to_vec(TensorMap(t))
+    function DiagonalTensorMap_from_vec(x_vec)
+        return DiagonalTensorMap(back(x_vec))
+    end
+    return x_vec, DiagonalTensorMap_from_vec
+end
+
 end
 
 # TODO: Investigate why the approach below doesn't work
diff --git a/test/ad.jl b/test/ad.jl
index 93251eb03..044e89534 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -3,6 +3,7 @@ using ChainRulesTestUtils
 using FiniteDifferences: FiniteDifferences
 using Random
 using LinearAlgebra
+using Zygote
 
 const _repartition = @static if isdefined(Base, :get_extension)
     Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
@@ -15,6 +16,10 @@ end
 function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
     return randn!(similar(x))
 end
+function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
+    V = x.domain
+    return DiagonalTensorMap(randn(rng, eltype(x), reduceddim(V)), V)
+end
 ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
 function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
                                          expected::AbstractTensorMap, msg=""; kwargs...)
@@ -152,6 +157,46 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
             test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
                        fkwargs=(; tol=Inf))
         end
+
+        test_rrule(Base.getproperty, T1, :data)
+        test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space)
+        test_rrule(Base.getproperty, T2, :data)
+        test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space)
+    end
+
+    @timedtestset "Basic utility (DiagonalTensor)" begin
+        for v in V
+            rdim = reduceddim(v)
+            D1 = DiagonalTensorMap(randn(rdim), v)
+            D2 = DiagonalTensorMap(randn(rdim), v)
+            D = D1 + im * D2
+            T1 = TensorMap(D1)
+            T2 = TensorMap(D2)
+            T = T1 + im * T2
+
+            # real -> real
+            P1 = ProjectTo(D1)
+            @test P1(D1) == D1
+            @test P1(T1) == D1
+
+            # complex -> complex
+            P2 = ProjectTo(D)
+            @test P2(D) == D
+            @test P2(T) == D
+
+            # real -> complex
+            @test P2(D1) == D1 + 0 * im * D1
+            @test P2(T1) == D1 + 0 * im * D1
+
+            # complex -> real
+            @test P1(D) == D1
+            @test P1(T) == D1
+
+            test_rrule(DiagonalTensorMap, D1.data, D1.domain)
+            test_rrule(DiagonalTensorMap, D.data, D.domain)
+            test_rrule(Base.getproperty, D, :data)
+            test_rrule(Base.getproperty, D1, :data)
+        end
     end
 
     @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
@@ -196,6 +241,21 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
         test_rrule(LinearAlgebra.dot, A, B)
     end
 
+    @timedtestset "Matrix functions ($T)" for T in eltypes
+        for f in (sqrt, exp)
+            check_inferred = false # !(T <: Real) # not type-stable for real functions
+            t1 = randn(T, V[1] ← V[1])
+            t2 = randn(T, V[2] ← V[2])
+            d = DiagonalTensorMap{T}(undef, V[1])
+            (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data)
+            d2 = DiagonalTensorMap{T}(undef, V[1])
+            (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data)
+            test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred)
+            test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred)
+            test_rrule(f, d; check_inferred, output_tangent=d2)
+        end
+    end
+
     symmetricbraiding &&
         @timedtestset "TensorOperations with scalartype $T" for T in eltypes
         atol = precision(T)
diff --git a/test/runtests.jl b/test/runtests.jl
index 1f06191cc..d0cd9945b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -60,7 +60,8 @@ include("spaces.jl")
 include("tensors.jl")
 include("diagonal.jl")
 include("planar.jl")
-if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS
+# TODO: remove once we know why AD is so slow on macOS CI
+if !(Sys.isapple() && get(ENV, "CI", "false") == "true")
     include("ad.jl")
 end
 include("bugfixes.jl")