Skip to content

Commit bce68b1

Browse files
committed
Now with tangents
1 parent fb52897 commit bce68b1

File tree

15 files changed

+169
-73
lines changed

15 files changed

+169
-73
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ jobs:
3131
- tensors
3232
- other
3333
- mooncake
34-
- enzyme
34+
- enzyme/factorizations
35+
- enzyme/linalg
36+
- enzyme/tensoroperations
37+
- enzyme/vectorinterface
38+
- enzyme/indexmanipulations
3539
- chainrules
3640
os:
3741
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2222
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
25+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
2526
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2627
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2728
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
@@ -31,6 +32,7 @@ TensorKitAdaptExt = "Adapt"
3132
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3233
TensorKitChainRulesCoreExt = "ChainRulesCore"
3334
TensorKitEnzymeExt = "Enzyme"
35+
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
3436
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3537
TensorKitMooncakeExt = "Mooncake"
3638

ext/TensorKitEnzymeExt/indexmanipulations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ function EnzymeRules.augmented_primal(
9090
β::Annotation{<:Number},
9191
ba::Const...
9292
) where {RT}
93-
C_cache = !isa(β, Const) ? copy(C.val) : nothing
94-
A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
93+
C_cache = !isa(β, Const) ? deepcopy(C.val) : nothing
94+
A_cache = EnzymeRules.overwritten(config)[3] ? deepcopy(A.val) : nothing
9595
# if we need to compute Δa, it is faster to allocate an intermediate braided A
9696
# and store that instead of repeating the permutation in the pullback each time.
9797
# effectively, we replace `add_permute` by `add ∘ permute`.

ext/TensorKitEnzymeExt/linalg.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,47 @@ function EnzymeRules.reverse(
110110
)
111111
return (nothing,)
112112
end
113-
113+
function EnzymeRules.augmented_primal(
114+
config::EnzymeRules.RevConfigWidth{1},
115+
func::Const{typeof(norm)},
116+
::Type{RT},
117+
A::Annotation{<:AbstractTensorMap},
118+
p::Const{<:Real},
119+
) where {RT}
120+
ret = func.val(A.val, p.val)
121+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
122+
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
123+
cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
124+
cache = (ret, cacheA)
125+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
126+
end
127+
function EnzymeRules.reverse(
128+
config::EnzymeRules.RevConfigWidth{1},
129+
func::Const{typeof(norm)},
130+
dret::Active,
131+
cache,
132+
A::Annotation{<:AbstractTensorMap},
133+
p::Const{<:Real},
134+
)
135+
n, cacheA = cache
136+
Δn = dret.val
137+
Aval = something(cacheA, A.val)
138+
if !isa(A, Const)
139+
x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
140+
add!(A.dval, A.val, x)
141+
end
142+
return (nothing, nothing)
143+
end
144+
function EnzymeRules.reverse(
145+
config::EnzymeRules.RevConfigWidth{1},
146+
func::Const{typeof(norm)},
147+
::Type{<:Const},
148+
cache,
149+
A::Annotation{<:AbstractTensorMap},
150+
p::Const{<:Real},
151+
)
152+
return (nothing, nothing)
153+
end
114154
function EnzymeRules.augmented_primal(
115155
config::EnzymeRules.RevConfigWidth{1},
116156
func::Const{typeof(inv)},

ext/TensorKitEnzymeExt/utility.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ end
5454
# Ignore derivatives
5555
# ------------------
5656

57+
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true
58+
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true
59+
5760
@inline EnzymeRules.inactive(::typeof(TensorKit.fusionblockstructure), arg) = nothing
5861
@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing
5962
@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing

ext/TensorKitEnzymeTestUtilsExt.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
module TensorKitEnzymeTestUtilsExt
2+
3+
using TensorKit
4+
using EnzymeTestUtils
5+
using EnzymeTestUtils: Enzyme
6+
import EnzymeTestUtils: to_vec, from_vec, rand_tangent
7+
8+
function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
9+
has_seen = haskey(seen_vecs, x)
10+
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
11+
if has_seen || is_const
12+
x_vec = Float32[]
13+
else
14+
vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)]
15+
x_vec, back = to_vec(vec_of_vecs)
16+
seen_vecs[x] = x_vec
17+
end
18+
function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict)
19+
if xor(has_seen, haskey(seen_xs, x))
20+
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
21+
end
22+
has_seen && return seen_xs[x]
23+
is_const && return x
24+
25+
x_new = similar(x)
26+
xvec_of_vecs = back(x_vec_new)
27+
for (i, (c, b)) in enumerate(blocks(x_new))
28+
scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c))
29+
end
30+
if Core.Typeof(x_new) != Core.Typeof(x)
31+
x_new = Core.Typeof(x)(x_new)
32+
end
33+
seen_xs[x] = x_new
34+
return x_new
35+
end
36+
return x_vec, TensorMap_from_vec
37+
end
38+
function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
39+
parent_vec, parent_t = to_vec(parent(t), seen_vecs)
40+
return parent_vec, adjoint parent_t
41+
end
42+
43+
# generate random tangents for testing
44+
function EnzymeTestUtils.rand_tangent(rng, t::TensorMap)
45+
return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t))
46+
end
47+
48+
function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap)
49+
return adjoint(rand_tangent(rng, parent(t)))
50+
end
51+
52+
function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap)
53+
return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1))
54+
end
55+
56+
end

test/enzyme/indexmanipulations/add_braid.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,21 @@ spacelist = (
4242
eltypes = (Float64, ComplexF64)
4343

4444
@timedtestset "Enzyme - Index Manipulations (add_braid!):" begin
45-
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
45+
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)$Tα$Tβ" for V in spacelist, T in eltypes, Tα in (Active, Const), Tβ in (Active, Const)
4646
atol = default_tol(T)
4747
rtol = default_tol(T)
4848
Vstr = TensorKit.type_repr(sectortype(eltype(V)))
49-
@timedtestset "add_braid! Tα $Tα$Tβ" forin (Active, Const), Tβ in (Active, Const)
50-
A = randn(T, V[1] V[2] V[4] V[5])
51-
α = randn(T)
52-
β = randn(T)
53-
p = randcircshift(numout(A), numin(A))
54-
levels = Tuple(randperm(numind(A)))
55-
C = randn!(transpose(A, p))
56-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
57-
if !(T <: Real)
58-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
59-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
60-
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
61-
end
49+
A = randn(T, V[1] V[2] V[4] V[5])
50+
α = randn(T)
51+
β = randn(T)
52+
p = randcircshift(numout(A), numin(A))
53+
levels = Tuple(randperm(numind(A)))
54+
C = randn!(transpose(A, p))
55+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
56+
if !(T <: Real)
57+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
58+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
59+
EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "add_braid! V $Vstr$Tα$Tβ")
6260
end
6361
end
6462
end

test/enzyme/indexmanipulations/add_transpose.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,22 @@ spacelist = (
4242
eltypes = (Float64, ComplexF64)
4343

4444
@timedtestset "Enzyme - Index Manipulations (add_transpose!):" begin
45-
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
45+
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)$Tα$Tβ" for V in spacelist, T in eltypes, Tα in (Const, Active), Tβ in (Const, Active)
4646
atol = default_tol(T)
4747
rtol = default_tol(T)
48+
A = randn(T, V[1] V[2] V[4] V[5])
49+
α = randn(T)
50+
β = randn(T)
4851

49-
@timedtestset "add_transpose! Tα $Tα$Tβ" forin (Const, Active), Tβ in (Const, Active)
50-
A = randn(T, V[1] V[2] V[4] V[5])
51-
α = randn(T)
52-
β = randn(T)
53-
54-
# repeat a couple times to get some distribution of arrows
55-
@testet for ri in 1:2
56-
p = randcircshift(numout(A), numin(A))
57-
C = randn!(transpose(A, p))
58-
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol)
59-
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
60-
if !(T <: Real)
61-
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
62-
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
63-
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
64-
end
65-
A = C
66-
end
52+
# repeat a couple times to get some distribution of arrows
53+
p = randcircshift(numout(A), numin(A))
54+
C = randn!(transpose(A, p))
55+
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol)
56+
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
57+
if !(T <: Real)
58+
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol)
59+
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
60+
EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol)
6761
end
6862
end
6963
end

test/enzyme/indexmanipulations/twist.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,16 @@ spacelist = (
4141
)
4242
eltypes = (Float64, ComplexF64)
4343

44-
@timedtestset verbose = true "Enzyme - Index Manipulations (twist):" begin
45-
@timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
44+
@timedtestset "Enzyme - Index Manipulations (twist):" begin
45+
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, TA in (Duplicated,)
4646
atol = default_tol(T)
4747
rtol = default_tol(T)
4848
A = randn(T, V[1] V[2] V[4] V[5])
4949
if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real))
50-
for TA in (Duplicated,)
51-
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
52-
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,))
53-
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol)
54-
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol)
55-
end
50+
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,))
51+
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,))
52+
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol)
53+
EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol)
5654
end
5755
end
5856
end

test/enzyme/linalg/mul.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ eltypes = (Float64, ComplexF64)
4343
α = randn(T)
4444
β = randn(T)
4545

46-
@testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Const, Duplicated), TA in (Const, Duplicated), TB in (Const, Duplicated)
46+
@testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Duplicated,), TA in (Duplicated,), TB in (Duplicated,)
4747
@testset "$Tα, Tβ $Tβ" forin (Active, Const), Tβ in (Active, Const)
4848
EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol)
4949
end

0 commit comments

Comments
 (0)