Merged
Changes from 5 commits
Commits
29 commits
2b79107
adds `ProjectTo` for `DiagonalTensorMap`
ebelnikola Jan 19, 2025
f275175
adds an `rrule` for `DiagonalTensorMap` constructor
ebelnikola Jan 19, 2025
2cf4cb4
Corrects bug in the DiagonalTensorMap rrule, adds
ebelnikola Jan 20, 2025
c43535f
@test missing in the constructor test added...
ebelnikola Jan 20, 2025
4653113
wait no, @test did not belong there
ebelnikola Jan 20, 2025
c81e43c
Update ext/TensorKitChainRulesCoreExt/utility.jl
ebelnikola Jan 20, 2025
41d1113
mixed type tests for ProjectTo
ebelnikola Jan 20, 2025
53b399c
+ rrule test on complex tensors.
ebelnikola Jan 20, 2025
8b51b8b
correct data length for DiagonalTensor in tests
ebelnikola Jan 20, 2025
2905c0b
correct data length in DiagonalTensorMap for random tangents
ebelnikola Jan 20, 2025
79a38b6
Comment on the test failure
ebelnikola Jan 20, 2025
689b4ea
Jutho's corrections
ebelnikola Jan 23, 2025
e1fe3be
Add `DiagonalTensorMap(::AbstractTensorMap)`
lkdvos Jan 23, 2025
ccb7c93
Specialize `to_vec(::DiagonalTensorMap)`
lkdvos Jan 23, 2025
f8ca1fd
Add rrules matrix functions
lkdvos Jan 23, 2025
98dea08
Add tests AD of matrixfunctions
lkdvos Jan 23, 2025
58904f7
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 5, 2025
5eec371
Remove duplicate methods
lkdvos Feb 6, 2025
9e834c4
disable broken tests
lkdvos Feb 6, 2025
a5ba340
Fix CI check
lkdvos Feb 6, 2025
3e45ff2
Adapt rrules for constructors and getproperty to include qdims
lkdvos Feb 7, 2025
f9ed03e
exchange sqrt and invsqrt in hope of fixing without thinking
lkdvos Feb 7, 2025
b195929
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 7, 2025
643aba7
Actually think to fix the problem
lkdvos Feb 7, 2025
06d3784
Simplify positive data generation
lkdvos Feb 7, 2025
2732f66
simplify CI detection
lkdvos Feb 7, 2025
455bcc2
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 7, 2025
8c08866
Fix bad merge
lkdvos Feb 7, 2025
618aaad
uncomment non-ad tests
Jutho Feb 7, 2025
11 changes: 11 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
@@ -12,6 +12,17 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg
    return TensorMap(d, args...; kwargs...), TensorMap_pullback
end

function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...;
                              kwargs...)
    D = DiagonalTensorMap(d, args...; kwargs...)
    project_D = ProjectTo(D)
    function DiagonalTensorMap_pullback(Δt)
        ∂d = project_D(unthunk(Δt)).data
        return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
    end
Member

This implementation slightly surprises me: I would have expected the projection to be based on d. In the end, do these two approaches boil down to the same thing?

Suggested change
-    D = DiagonalTensorMap(d, args...; kwargs...)
-    project_D = ProjectTo(D)
-    function DiagonalTensorMap_pullback(Δt)
-        ∂d = project_D(unthunk(Δt)).data
-        return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
-    end
+    project_d = ProjectTo(d)
+    function DiagonalTensorMap_pullback(Δt)
+        ∂d = project_d(unthunk(Δt).data)
+        return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
+    end
+    D = DiagonalTensorMap(d, args...; kwargs...)

Contributor Author

@ebelnikola ebelnikola Jan 20, 2025

It seems to me this is not the same. The issue here is that Δt may sometimes be of a non-diagonal type. I expect that in this situation your version will return an incorrect tangent (with a lot of zeros from the off-diagonal parts of Δt). Am I missing something?

Member

Ah, good point, I think I missed this.
Somehow, I was expecting the input to the pullback to always be a DiagonalTensorMap, since this really is a constructor, and if not, there's probably a projector missing in some other rrule...
I also found some comments in the rrules for Diagonal where a similar discussion is taking place.
Long story short though, it seems like their solution is more similar to yours, so that's definitely okay for me.

Member

I am also undecided between the two options. I agree that in most cases, e.g. tensor contractions, ProjectTo(D) should already have been called on the adjoint variable Δt before it enters this rrule, and therefore not explicitly having project_d(Δt) in this rule might serve as a way to find missing projectors elsewhere.

On the other hand, from a user perspective, I can also see the advantage of just having project_D in here for safety.

While in principle that is independent of still having to call project_d = ProjectTo(d) on the output ∂d, it is probably true that project_d will not be doing anything (will act as the identity) if we already have projected Δt to a diagonal tensor with the right scalar type and storage type.
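
A minimal sketch of the distinction raised in this thread, using plain `LinearAlgebra` types as stand-ins rather than TensorKit itself. It assumes that a dense matrix plays the role of a non-diagonal Δt, that `diag` plays the role of `project_D(...).data`, and that `vec` plays the role of reading `.data` from a dense tensor. The helper names `project_then_read` and `read_storage` are hypothetical and not part of the PR.

```julia
using LinearAlgebra

# Analogue of `project_D(unthunk(Δt)).data`:
# first restrict the cotangent to its diagonal, then read the stored entries.
project_then_read(Δ::AbstractMatrix) = diag(Δ)

# Analogue of `unthunk(Δt).data` (assumption: raw storage of the cotangent).
# For a diagonal cotangent this is just the diagonal; for a dense one it is
# every entry, off-diagonal ones included.
read_storage(Δ::Diagonal) = Δ.diag
read_storage(Δ::Matrix) = vec(Δ)

# A dense cotangent, e.g. produced by an upstream rule that did not project.
Δ_dense = [1.0 2.0 3.0;
           4.0 5.0 6.0;
           7.0 8.0 9.0]

# A cotangent that is already diagonal.
Δ_diag = Diagonal([1.0, 5.0, 9.0])

# Projecting first always yields one entry per diagonal element.
@show project_then_read(Δ_dense)

# Reading the storage of a dense cotangent yields all 9 entries, which no
# longer matches the 3-element input `d` of the constructor.
@show length(read_storage(Δ_dense))

# For an already-diagonal cotangent both variants coincide.
@show project_then_read(Δ_diag) == read_storage(Δ_diag)
```

In this toy setting the two variants only differ when the incoming cotangent is not already diagonal, which is consistent with the remark that the choice matters mainly when a projector is missing in some upstream rule.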

    return D, DiagonalTensorMap_pullback
end

function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    copy_pullback(Δt) = NoTangent(), Δt
    return copy(t), copy_pullback
12 changes: 12 additions & 0 deletions ext/TensorKitChainRulesCoreExt/utility.jl
@@ -29,3 +29,15 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N
    end
    return y
end

function (::ProjectTo{T1})(x::T2) where {S,NumType,StorType,
                                         T1<:DiagonalTensorMap{NumType,S,StorType},
                                         T2<:AbstractTensorMap{<:Any,S,1,1}}
    T1 === T2 && return x
    y = DiagonalTensorMap{NumType,S,StorType}(undef, space(x, 1))
    for (c, b) in blocks(y)
        p = ProjectTo(b)
        b .= p(block(x, c))
    end
    return y
end
17 changes: 17 additions & 0 deletions test/ad.jl
@@ -15,6 +15,9 @@ end
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
    return randn!(similar(x))
end
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
    return DiagonalTensorMap(randn(eltype(x), dim(x.domain)), x.domain)
end
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
                                         expected::AbstractTensorMap, msg=""; kwargs...)
@@ -144,6 +147,20 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
fkwargs=(; tol=Inf))
end

@timedtestset "Basic utility (DiagonalTensor)" begin
    for NumType in [Float64, ComplexF64]
        for v in V
            T1 = DiagonalTensorMap(randn(NumType, dim(v)), v)
            T2 = TensorMap(T1)

            P1 = ProjectTo(T1)
            @test P1(T2) == T1

            test_rrule(DiagonalTensorMap, T1.data, T1.domain)
        end
    end
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
    A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
    B = randn(T, space(A))