Skip to content

Commit fb52897

Browse files
committed
Incremental work on Enzyme support
1 parent 66187bb commit fb52897

File tree

36 files changed

+2673
-5
lines changed

36 files changed

+2673
-5
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
- tensors
3232
- other
3333
- mooncake
34+
- enzyme
3435
- chainrules
3536
os:
3637
- ubuntu-latest
@@ -57,6 +58,7 @@ jobs:
5758
- tensors
5859
- other
5960
- mooncake
61+
- enzyme
6062
- chainrules
6163
os:
6264
- ubuntu-latest

Project.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2121
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2222
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2425
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2526
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2627
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
@@ -29,6 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2930
TensorKitAdaptExt = "Adapt"
3031
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3132
TensorKitChainRulesCoreExt = "ChainRulesCore"
33+
TensorKitEnzymeExt = "Enzyme"
3234
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3335
TensorKitMooncakeExt = "Mooncake"
3436

@@ -41,6 +43,8 @@ CUDA = "5.9"
4143
ChainRulesCore = "1"
4244
ChainRulesTestUtils = "1"
4345
Combinatorics = "1"
46+
Enzyme = "0.13.131"
47+
EnzymeTestUtils = "0.2.5"
4448
FiniteDifferences = "0.12"
4549
GPUArrays = "11.3.1"
4650
JET = "0.9, 0.10, 0.11"
@@ -53,7 +57,7 @@ Printf = "1"
5357
Random = "1"
5458
SafeTestsets = "0.1"
5559
ScopedValues = "1.3.0"
56-
Strided = "2"
60+
Strided = "=2.3.3"
5761
TensorKitSectors = "0.3.5"
5862
TensorOperations = "5.1"
5963
Test = "1"
@@ -73,6 +77,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
7377
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7478
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
7579
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
80+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
81+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
7682
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
7783
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
7884
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -86,4 +92,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8692
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
8793

8894
[targets]
89-
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
95+
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils", "JET"]
96+
97+
[sources]
98+
TensorOperations = {url = "https://github.com/quantumkithub/tensoroperations.jl", rev = "ksh/enzyme_update"}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module TensorKitEnzymeExt
2+
3+
using Enzyme
4+
using TensorKit
5+
import TensorKit as TK
6+
using VectorInterface
7+
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
8+
import TensorOperations as TO
9+
using MatrixAlgebraKit
10+
using TupleTools
11+
using Random: AbstractRNG
12+
13+
include("utility.jl")
14+
include("linalg.jl")
15+
include("vectorinterface.jl")
16+
include("tensoroperations.jl")
17+
include("factorizations.jl")
18+
include("indexmanipulations.jl")
19+
#include("planaroperations.jl")
20+
21+
end
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
function EnzymeRules.reverse(
2+
config::EnzymeRules.RevConfigWidth{1},
3+
func::Const{typeof(MatrixAlgebraKit.copy_input)},
4+
::Type{RT},
5+
cache,
6+
f::Annotation,
7+
A::Annotation{<:AbstractTensorMap}
8+
) where {RT}
9+
copy_shadow = cache
10+
if !isa(A, Const) && !isnothing(copy_shadow)
11+
add!(A.dval, copy_shadow)
12+
end
13+
return (nothing, nothing)
14+
end
15+
16+
for (f, pb) in (
17+
(:eig_full, :(MatrixAlgebraKit.eig_pullback!)),
18+
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
19+
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
20+
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
21+
)
22+
@eval begin
23+
function EnzymeRules.augmented_primal(
24+
config::EnzymeRules.RevConfigWidth{1},
25+
func::Const{typeof($f)},
26+
::Type{RT},
27+
A::Annotation{<:AbstractTensorMap},
28+
alg::Const,
29+
) where {RT}
30+
ret = $f(A.val, alg.val)
31+
dret = make_zero(ret)
32+
cache = (ret, dret)
33+
return EnzymeRules.AugmentedReturn(ret, dret, cache)
34+
end
35+
function EnzymeRules.reverse(
36+
config::EnzymeRules.RevConfigWidth{1},
37+
func::Const{typeof($f)},
38+
::Type{RT},
39+
cache,
40+
A::Annotation{<:AbstractTensorMap},
41+
alg::Const,
42+
) where {RT}
43+
ret, dret = cache
44+
$pb(A.dval, A.val, ret, dret)
45+
return (nothing, nothing)
46+
end
47+
end
48+
end
49+
50+
for f in (:svd_compact, :svd_full)
51+
@eval begin
52+
function EnzymeRules.augmented_primal(
53+
config::EnzymeRules.RevConfigWidth{1},
54+
func::Const{typeof($f)},
55+
::Type{RT},
56+
A::Annotation{<:AbstractTensorMap},
57+
alg::Const,
58+
) where {RT}
59+
USVᴴ = $f(A.val, alg.val)
60+
dUSVᴴ = make_zero(USVᴴ)
61+
cache = (USVᴴ, dUSVᴴ)
62+
return EnzymeRules.AugmentedReturn(USVᴴ, dUSVᴴ, cache)
63+
end
64+
function EnzymeRules.reverse(
65+
config::EnzymeRules.RevConfigWidth{1},
66+
func::Const{typeof($f)},
67+
::Type{RT},
68+
cache,
69+
A::Annotation{<:AbstractTensorMap},
70+
alg::Const,
71+
) where {RT}
72+
USVᴴ, dUSVᴴ = cache
73+
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
74+
return (nothing, nothing)
75+
end
76+
end
77+
78+
# mutating version is not guaranteed to actually mutate
79+
# so we can simply use the non-mutating version instead
80+
f! = Symbol(f, :!)
81+
#=@eval begin
82+
function EnzymeRules.augmented_primal(
83+
config::EnzymeRules.RevConfigWidth{1},
84+
func::Const{typeof($f!)},
85+
::Type{RT},
86+
A::Annotation{<:AbstractTensorMap},
87+
USVᴴ::Annotation,
88+
alg::Const,
89+
) where {RT}
90+
EnzymeRules.augmented_primal(func, RT, A, alg)
91+
end
92+
function EnzymeRules.reverse(
93+
config::EnzymeRules.RevConfigWidth{1},
94+
func::Const{typeof($f!)},
95+
::Type{RT},
96+
cache,
97+
A::Annotation{<:AbstractTensorMap},
98+
USVᴴ::Annotation,
99+
alg::Const,
100+
) where {RT}
101+
EnzymeRules.reverse(func, RT, A, alg)
102+
end
103+
end=# #hmmmm
104+
end
105+
106+
# TODO
107+
#=
108+
function EnzymeRules.augmented_primal(
109+
config::EnzymeRules.RevConfigWidth{1},
110+
func::Const{typeof(svd_trunc)},
111+
::Type{RT},
112+
A::Annotation{<:AbstractTensorMap},
113+
alg::Const,
114+
) where {RT}
115+
116+
USVᴴ = svd_compact(A.val, alg.val.alg)
117+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
118+
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
119+
dUSVᴴtrunc = make_zero(USVᴴtrunc)
120+
cache = (USVᴴtrunc, dUSVᴴtrunc)
121+
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
122+
end
123+
function EnzymeRules.reverse(
124+
config::EnzymeRules.RevConfigWidth{1},
125+
func::Const{typeof(svd_trunc)},
126+
::Type{RT},
127+
cache,
128+
A::Annotation{<:AbstractTensorMap},
129+
alg::Const,
130+
) where {RT}
131+
USVᴴ, dUSVᴴ = cache
132+
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
133+
return (nothing, nothing)
134+
end=#

0 commit comments

Comments
 (0)