|
4 | 4 | @non_differentiable TensorKit.isometry(args...) |
5 | 5 | @non_differentiable TensorKit.unitary(args...) |
6 | 6 |
|
7 | | -function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...) |
| 7 | +function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...) |
8 | 8 | function TensorMap_pullback(Δt) |
9 | 9 | ∂d = convert(Array, unthunk(Δt)) |
10 | 10 | return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... |
11 | 11 | end |
12 | 12 | return TensorMap(d, args...; kwargs...), TensorMap_pullback |
13 | 13 | end |
14 | 14 |
|
| 15 | +# these are not the conversion to/from array, but actually take in data parameters |
| 16 | +# -- as a result, requires quantum dimensions to keep inner product the same: |
| 17 | +# ⟨Δdata, ∂data⟩ = ⟨Δtensor, ∂tensor⟩ = ∑_c d_c ⟨Δtensor_c, ∂tensor_c⟩ |
| 18 | +# ⟹ Δdata = d_c Δtensor_c |
| 19 | +function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector, |
| 20 | + V::TensorMapSpace) where {T} |
| 21 | + t = TensorMap{T}(data, V) |
| 22 | + P = ProjectTo(data) |
| 23 | + function TensorMap_pullback(Δt_) |
| 24 | + Δt = copy(unthunk(Δt_)) |
| 25 | + for (c, b) in blocks(Δt) |
| 26 | + scale!(b, dim(c)) |
| 27 | + end |
| 28 | + ∂data = P(Δt.data) |
| 29 | + return NoTangent(), ∂data, NoTangent() |
| 30 | + end |
| 31 | + return t, TensorMap_pullback |
| 32 | +end |
| 33 | + |
| 34 | +function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, args...; |
| 35 | + kwargs...) |
| 36 | + D = DiagonalTensorMap(data, args...; kwargs...) |
| 37 | + P = ProjectTo(data) |
| 38 | + function DiagonalTensorMap_pullback(Δt_) |
| 39 | + # unclear if we're allowed to modify/take ownership of the input |
| 40 | + Δt = copy(unthunk(Δt_)) |
| 41 | + for (c, b) in blocks(Δt) |
| 42 | + scale!(b, dim(c)) |
| 43 | + end |
| 44 | + ∂data = P(Δt.data) |
| 45 | + return NoTangent(), ∂data, NoTangent() |
| 46 | + end |
| 47 | + return D, DiagonalTensorMap_pullback |
| 48 | +end |
| 49 | + |
| 50 | +function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol) |
| 51 | + if prop === :data |
| 52 | + function getdata_pullback(Δdata) |
| 53 | + # unclear if we're allowed to modify/take ownership of the input |
| 54 | + t′ = typeof(t)(copy(unthunk(Δdata)), t.space) |
| 55 | + for (c, b) in blocks(t′) |
| 56 | + scale!(b, inv(dim(c))) |
| 57 | + end |
| 58 | + return NoTangent(), t′, NoTangent() |
| 59 | + end |
| 60 | + return t.data, getdata_pullback |
| 61 | + elseif prop === :space |
| 62 | + return t.space, Returns((NoTangent(), ZeroTangent(), NoTangent())) |
| 63 | + else |
| 64 | + throw(ArgumentError("unknown property $prop")) |
| 65 | + end |
| 66 | +end |
| 67 | + |
| 68 | +function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap, |
| 69 | + prop::Symbol) |
| 70 | + if prop === :data |
| 71 | + function getdata_pullback(Δdata) |
| 72 | + # unclear if we're allowed to modify/take ownership of the input |
| 73 | + t′ = typeof(t)(copy(unthunk(Δdata)), t.domain) |
| 74 | + for (c, b) in blocks(t′) |
| 75 | + scale!(b, inv(dim(c))) |
| 76 | + end |
| 77 | + return NoTangent(), t′, NoTangent() |
| 78 | + end |
| 79 | + return t.data, getdata_pullback |
| 80 | + elseif prop === :domain |
| 81 | + return t.domain, Returns((NoTangent(), ZeroTangent(), NoTangent())) |
| 82 | + else |
| 83 | + throw(ArgumentError("unknown property $prop")) |
| 84 | + end |
| 85 | +end |
| 86 | + |
15 | 87 | function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) |
16 | 88 | copy_pullback(Δt) = NoTangent(), Δt |
17 | 89 | return copy(t), copy_pullback |
|
0 commit comments