Skip to content

Commit 52e7eaf

Browse files
committed
Now with tangents
1 parent fb52897 commit 52e7eaf

File tree

18 files changed

+188
-88
lines changed

18 files changed

+188
-88
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ jobs:
3131
- tensors
3232
- other
3333
- mooncake
34-
- enzyme
34+
- enzyme/factorizations
35+
- enzyme/linalg
36+
- enzyme/tensoroperations
37+
- enzyme/vectorinterface
38+
- enzyme/indexmanipulations
3539
- chainrules
3640
os:
3741
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2222
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
25+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
2526
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2627
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2728
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
@@ -31,6 +32,7 @@ TensorKitAdaptExt = "Adapt"
3132
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3233
TensorKitChainRulesCoreExt = "ChainRulesCore"
3334
TensorKitEnzymeExt = "Enzyme"
35+
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
3436
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3537
TensorKitMooncakeExt = "Mooncake"
3638

ext/TensorKitEnzymeExt/factorizations.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ for (f, pb) in (
2828
alg::Const,
2929
) where {RT}
3030
ret = $f(A.val, alg.val)
31-
dret = make_zero(ret)
32-
cache = (ret, dret)
33-
return EnzymeRules.AugmentedReturn(ret, dret, cache)
31+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
32+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
33+
cache = (ret, shadow)
34+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
3435
end
3536
function EnzymeRules.reverse(
3637
config::EnzymeRules.RevConfigWidth{1},
@@ -40,8 +41,7 @@ for (f, pb) in (
4041
A::Annotation{<:AbstractTensorMap},
4142
alg::Const,
4243
) where {RT}
43-
ret, dret = cache
44-
$pb(A.dval, A.val, ret, dret)
44+
!isa(A, Const) && $pb(A.dval, A.val, cache...)
4545
return (nothing, nothing)
4646
end
4747
end
@@ -57,9 +57,10 @@ for f in (:svd_compact, :svd_full)
5757
alg::Const,
5858
) where {RT}
5959
USVᴴ = $f(A.val, alg.val)
60-
dUSVᴴ = make_zero(USVᴴ)
61-
cache = (USVᴴ, dUSVᴴ)
62-
return EnzymeRules.AugmentedReturn(USVᴴ, dUSVᴴ, cache)
60+
primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing
61+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing
62+
cache = (USVᴴ, shadow)
63+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
6364
end
6465
function EnzymeRules.reverse(
6566
config::EnzymeRules.RevConfigWidth{1},
@@ -69,8 +70,7 @@ for f in (:svd_compact, :svd_full)
6970
A::Annotation{<:AbstractTensorMap},
7071
alg::Const,
7172
) where {RT}
72-
USVᴴ, dUSVᴴ = cache
73-
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
73+
!isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...)
7474
return (nothing, nothing)
7575
end
7676
end

ext/TensorKitEnzymeExt/indexmanipulations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ function EnzymeRules.augmented_primal(
9090
β::Annotation{<:Number},
9191
ba::Const...
9292
) where {RT}
93-
C_cache = !isa(β, Const) ? copy(C.val) : nothing
94-
A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
93+
C_cache = !isa(β, Const) ? deepcopy(C.val) : nothing
94+
A_cache = EnzymeRules.overwritten(config)[3] ? deepcopy(A.val) : nothing
9595
# if we need to compute Δa, it is faster to allocate an intermediate braided A
9696
# and store that instead of repeating the permutation in the pullback each time.
9797
# effectively, we replace `add_permute` by `add ∘ permute`.

ext/TensorKitEnzymeExt/linalg.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,47 @@ function EnzymeRules.reverse(
110110
)
111111
return (nothing,)
112112
end
113-
113+
# Reverse-mode forward pass for `norm(A, p)` on tensor maps.
#
# Evaluates the norm once and returns an `AugmentedReturn` whose primal and
# shadow slots are filled only when `config` asks for them. The tape handed to
# the reverse pass is `(norm value, copy of A or nothing)`; `A` is copied only
# when Enzyme reports that it may be overwritten before the reverse pass runs.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(norm)},
        ::Type{RT},
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Real},
    ) where {RT}
    nrm = func.val(A.val, p.val)
    primal_out = if EnzymeRules.needs_primal(config)
        nrm
    else
        nothing
    end
    shadow_out = if EnzymeRules.needs_shadow(config)
        zero(nrm)
    else
        nothing
    end
    # `overwritten(config)` is indexed as (func, A, p), so entry 2 refers to `A`.
    saved_A = if EnzymeRules.overwritten(config)[2]
        copy(A.val)
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal_out, shadow_out, (nrm, saved_A))
end
127+
# Reverse-mode pullback for `norm(A, p)` with an `Active` (scalar) return.
#
# `cache` is the `(n, cacheA)` tuple produced by the matching
# `augmented_primal`: `n` is the primal norm and `cacheA` is either a copy of
# `A` taken before it could be overwritten, or `nothing` when no copy was
# needed. The cotangent is accumulated in place into `A.dval`; both returned
# slots are `nothing` because neither `A` nor `p` is an `Active` scalar.
#
# NOTE(review): the formula below does not depend on `p`; presumably only the
# default 2-norm is meant to be differentiated here. Confirm against callers.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(norm)},
        dret::Active,
        cache,
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Real},
    )
    n, cacheA = cache
    Δn = dret.val
    if !isa(A, Const)
        # Use the cached copy when `A` may have been overwritten since the
        # forward pass; reading `A.val` there would give the wrong gradient.
        Aval = something(cacheA, A.val)
        # (Δn' + Δn) / 2 is the real part of the incoming cotangent;
        # hypot(n, eps) keeps the division finite when the norm is zero.
        x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
        add!(A.dval, Aval, x)
    end
    return (nothing, nothing)
end
144+
# Reverse-mode pullback for `norm(A, p)` when the return value is annotated
# `Const`: no cotangent flows back, so nothing is accumulated and both
# argument slots report `nothing`.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(norm)},
        ::Type{<:Const},
        cache,
        A::Annotation{<:AbstractTensorMap},
        p::Const{<:Real},
    )
    return nothing, nothing
end
114154
function EnzymeRules.augmented_primal(
115155
config::EnzymeRules.RevConfigWidth{1},
116156
func::Const{typeof(inv)},

ext/TensorKitEnzymeExt/utility.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ end
5454
# Ignore derivatives
5555
# ------------------
5656

57+
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true
58+
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true
59+
5760
@inline EnzymeRules.inactive(::typeof(TensorKit.fusionblockstructure), arg) = nothing
5861
@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing
5962
@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing

ext/TensorKitEnzymeTestUtilsExt.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Package extension that adds `EnzymeTestUtils` methods for TensorKit tensor
# types: `to_vec` (flatten a tensor to a vector and get back a reconstruction
# function) and `rand_tangent` (random tangents for testing).
module TensorKitEnzymeTestUtilsExt

using TensorKit
using EnzymeTestUtils
using EnzymeTestUtils: Enzyme
import EnzymeTestUtils: to_vec, from_vec, rand_tangent

# Vectorize a `TensorMap`.
#
# Returns `(x_vec, from_vec)`. `x_vec` is built from the blocks of `x`, each
# multiplied by `sqrtdim` of its sector. It is empty when `x` was already
# vectorized (it is a key of `seen_vecs`) or when its type is guaranteed
# constant by Enzyme. `from_vec(x_vec_new, seen_xs)` rebuilds a tensor like `x`
# from a vector of the same layout, undoing the scaling with `invsqrtdim`.
#
# NOTE(review): the sqrtdim weighting presumably makes the Euclidean inner
# product of the vectors match the tensor inner product; confirm.
function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
    has_seen = haskey(seen_vecs, x)
    is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
    if has_seen || is_const
        x_vec = Float32[]
    else
        vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)]
        # `back` is only defined on this branch; the closure below only calls
        # it after both early returns for the seen/const cases.
        x_vec, back = to_vec(vec_of_vecs)
        seen_vecs[x] = x_vec
    end
    function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict)
        # Aliased objects must be reconstructed in the order they were vectorized.
        if xor(has_seen, haskey(seen_xs, x))
            throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
        end
        has_seen && return seen_xs[x]
        is_const && return x

        x_new = similar(x)
        xvec_of_vecs = back(x_vec_new)
        for (i, (c, b)) in enumerate(blocks(x_new))
            scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c))
        end
        # Convert back if `similar` did not return the same concrete type.
        if Core.Typeof(x_new) != Core.Typeof(x)
            x_new = Core.Typeof(x)(x_new)
        end
        seen_xs[x] = x_new
        return x_new
    end
    return x_vec, TensorMap_from_vec
end

# Vectorize an `AdjointTensorMap` through its parent; the reconstruction
# function rebuilds the parent and then takes its adjoint.
function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
    parent_vec, parent_t = to_vec(parent(t), seen_vecs)
    return parent_vec, adjoint ∘ parent_t
end

# Vectorize a `DiagonalTensorMap` by converting it to a `TensorMap`; the
# reconstruction function converts the rebuilt tensor back to a
# `DiagonalTensorMap`.
function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
    parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs)
    return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t
end

# generate random tangents for testing

# Random tangent for a `TensorMap`: a random tangent of its data on the same space.
function EnzymeTestUtils.rand_tangent(rng, t::TensorMap)
    return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t))
end

# Random tangent for an `AdjointTensorMap`: adjoint of a random tangent of its parent.
function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap)
    return adjoint(rand_tangent(rng, parent(t)))
end

# Random tangent for a `DiagonalTensorMap`: a random tangent of its data on the
# first space of `t`.
function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap)
    return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1))
end

end

test/enzyme/factorizations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using MatrixAlgebraKit
66
using Enzyme, EnzymeTestUtils
77
using Random
88

9-
@isdefined(TestSetup) || include("../setup.jl")
9+
@isdefined(TestSetup) || include("../../setup.jl")
1010
using .TestSetup
1111

1212
spacelist = (

test/enzyme/factorizations/svd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using MatrixAlgebraKit
66
using Enzyme, EnzymeTestUtils
77
using Random
88

9-
@isdefined(TestSetup) || include("../setup.jl")
9+
@isdefined(TestSetup) || include("../../setup.jl")
1010
using .TestSetup
1111

1212
spacelist = (
@@ -58,7 +58,7 @@ end
5858
atol = default_tol(T)
5959
rtol = default_tol(T)
6060
USVᴴ = svd_compact(t)
61-
ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), DiagonalTensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2], 1)), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3])))
61+
ΔUSVᴴ = EnzymeTestUtils.rand_tangent.(USVᴴ)
6262
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
6363
EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol)
6464

@@ -70,8 +70,8 @@ end
7070
V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
7171
trunc = truncspace(V_trunc)
7272
alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc)
73-
USVᴴtrunc = svd_trunc(t, alg)
74-
ΔUSVᴴtrunc = randn!(similar.(USVᴴtrunc))
73+
USVᴴtrunc = svd_trunc_no_error(t, alg)
74+
ΔUSVᴴtrunc = EnzymeTestUtils.rand_tangent.(USVᴴtrunc)
7575
remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...)
7676
EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol)
7777
end

test/enzyme/indexmanipulations/add_braid.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,21 @@ spacelist = (
4242
eltypes = (Float64, ComplexF64)
4343

4444
@timedtestset "Enzyme - Index Manipulations (add_braid!):" begin
45-
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
45+
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)$Tα$Tβ" for V in spacelist, T in eltypes, Tα in (Active, Const), Tβ in (Active, Const)
4646
atol = default_tol(T)
4747
rtol = default_tol(T)
4848
Vstr = TensorKit.type_repr(sectortype(eltype(V)))
49-
@timedtestset "add_braid! Tα $Tα$Tβ" for Tα in (Active, Const), Tβ in (Active, Const)
50-
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
51-
α = randn(T)
52-
β = randn(T)
53-
p = randcircshift(numout(A), numin(A))
54-
levels = Tuple(randperm(numind(A)))
55-
C = randn!(transpose(A, p))
56-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
57-
if !(T <: Real)
58-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
59-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
60-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
61-
end
49+
A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5])
50+
α = randn(T)
51+
β = randn(T)
52+
p = randcircshift(numout(A), numin(A))
53+
levels = Tuple(randperm(numind(A)))
54+
C = randn!(transpose(A, p))
55+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
56+
if !(T <: Real)
57+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
58+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
59+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
6260
end
6361
end
6462
end

0 commit comments

Comments
 (0)