-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathconstructors.jl
More file actions
71 lines (64 loc) · 2.91 KB
/
constructors.jl
File metadata and controls
71 lines (64 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Constructors that build tensors from a generating function, plus the special maps
# below (identity, isomorphisms, isometries, unitaries), are declared non-differentiable:
# AD treats their results as constants and no tangents flow back to their arguments.
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.id(args...)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)
# Reverse rule for constructing a `TensorMap` from a dense array `d`.
#
# Only `d` is differentiable; the remaining positional arguments (e.g. codomain and
# domain) and all keyword arguments receive no tangent.
#
# The cotangent tensor is converted back to a plain array, reshaped to the shape the
# caller used for `d`, and projected onto `d` (e.g. a real `d` gets a real cotangent).
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
    project_d = ProjectTo(d)
    function TensorMap_pullback(Δt)
        # `convert(Array, ·)` produces the tensor's own array layout, which need not match
        # the shape of `d` (NOTE(review): presumably (dims(codomain)..., dims(domain)...)).
        # The lengths agree, so reshape back to `size(d)` before projecting.
        ∂d = project_d(reshape(convert(Array, unthunk(Δt)), size(d)))
        return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
    end
    return TensorMap(d, args...; kwargs...), TensorMap_pullback
end
# Reverse rule for `copy`.
#
# Copying is the identity map, so the incoming cotangent is passed straight through
# (without unthunking) as the cotangent of `t`.
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    function copy_pullback(Δ)
        return (NoTangent(), Δ)
    end
    t_copy = copy(t)
    return t_copy, copy_pullback
end
# Reverse rule for `TensorKit.copy_oftype(t, T)`.
#
# The cotangent is unthunked and projected back onto `t` with `ProjectTo(t)`.
# The target element type `T` receives no tangent.
function ChainRulesCore.rrule(::typeof(TensorKit.copy_oftype), t::AbstractTensorMap,
                              T::Type{<:Number})
    projector = ProjectTo(t)
    function copy_oftype_pullback(Δ)
        ∂t = projector(unthunk(Δ))
        return (NoTangent(), ∂t, NoTangent())
    end
    result = TensorKit.copy_oftype(t, T)
    return result, copy_oftype_pullback
end
# Reverse rule for `TensorKit.permutedcopy_oftype(t, T, p)`.
#
# The cotangent is permuted back with the inverse of `p` and then projected onto `t`
# with `ProjectTo(t)`. The element type `T` and the permutation `p` receive no tangent.
function ChainRulesCore.rrule(::typeof(TensorKit.permutedcopy_oftype), t::AbstractTensorMap,
                              T::Type{<:Number}, p::Index2Tuple)
    projector = ProjectTo(t)
    function permutedcopy_oftype_pullback(Δ)
        # Invert the flattened permutation, then let `_canonicalize` re-split it to match
        # the index structure of `t` (presumably into codomain/domain groups).
        linear_inverse = TupleTools.invperm(linearize(p))
        p_back = TensorKit._canonicalize(linear_inverse, t)
        Δ_back = TensorKit.permute(unthunk(Δ), p_back)
        return (NoTangent(), projector(Δ_back), NoTangent(), NoTangent())
    end
    result = TensorKit.permutedcopy_oftype(t, T, p)
    return result, permutedcopy_oftype_pullback
end
# Reverse rule for `convert(T, t)` with `T <: Array`.
#
# The dense-array cotangent is turned back into a tensor with the same codomain and
# domain as `t`. The target type `T` receives no tangent.
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
                              t::AbstractTensorMap)
    function convert_pullback(ΔA)
        # Passing `tol=Inf` to the constructor makes it (unconditionally) project the
        # array back onto the symmetric subspace.
        ∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t); tol=Inf)
        return (NoTangent(), NoTangent(), ∂t)
    end
    return convert(T, t), convert_pullback
end
# Reverse rule for `convert(Dict, t)`.
#
# Only the `:data` entry of the Dict cotangent has a meaningful dual. That entry is
# written into a copy of the forward Dict, and the copy is converted back to a
# `TensorMap`. If the cotangent has no `:data` entry, the tangent of `t` is `zero(t)`.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
    dict_t = convert(Dict, t)
    function convert_pullback(c′)
        c = unthunk(c′)
        # `zero(t)` is returned here instead of `ZeroTangent()`, which would make the
        # pullback type unstable.
        haskey(c, :data) || return (NoTangent(), NoTangent(), zero(t))
        dual = copy(dict_t)
        dual[:data] = c[:data]
        return (NoTangent(), NoTangent(), convert(TensorMap, dual))
    end
    return dict_t, convert_pullback
end
# Reverse rule for rebuilding a `TensorMap` from its `Dict{Symbol,Any}` representation.
#
# The tensor cotangent is mapped back to Dict form with `convert(Dict, ·)`, which is the
# inverse of the forward conversion.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
                              t::Dict{Symbol,Any})
    function convert_pullback(Δt)
        # Unthunk first, as every other pullback in this file does: `convert(Dict, ·)`
        # expects a concrete tensor, not a lazy `Thunk` cotangent.
        return NoTangent(), NoTangent(), convert(Dict, unthunk(Δt))
    end
    return convert(TensorMap, t), convert_pullback
end
# Reverse rule for wrapping `t` in an `AdjointTensorMap`.
#
# The cotangent of `t` is the adjoint of the (unthunked) incoming cotangent.
function ChainRulesCore.rrule(T::Type{<:TensorKit.AdjointTensorMap}, t::AbstractTensorMap)
    function AdjointTensorMap_pullback(Δ)
        ∂t = adjoint(unthunk(Δ))
        return (NoTangent(), ∂t)
    end
    return T(t), AdjointTensorMap_pullback
end