Skip to content

Commit f275175

Browse files
committed
adds an rrule for DiagonalTensorMap constructor
1 parent 2b79107 commit f275175

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

ext/TensorKitChainRulesCoreExt/constructors.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg
1212
return TensorMap(d, args...; kwargs...), TensorMap_pullback
1313
end
1414

15+
function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; kwargs...)
16+
D=TensorMap(d, args...; kwargs...)
17+
project_D=ProjectTo(D)
18+
function DiagonalTensorMap_pullback(Δt)
19+
∂d = project_D(unthunk(Δt)).data
20+
return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
21+
end
22+
return D, DiagonalTensorMap_pullback
23+
end
24+
1525
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
1626
copy_pullback(Δt) = NoTangent(), Δt
1727
return copy(t), copy_pullback

0 commit comments

Comments
 (0)