Skip to content

Commit 3d9ca6c

Browse files
committed
refactor: move SparseArrays into an extension
1 parent 345b215 commit 3d9ca6c

File tree

7 files changed

+186
-173
lines changed

7 files changed

+186
-173
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
44
version = "0.13.75"
55

6-
76
[deps]
87
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
98
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
@@ -18,13 +17,13 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1817
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1918
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2019
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2220

2321
[weakdeps]
2422
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
2523
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2624
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2725
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
26+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2827
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2928
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3029

@@ -33,6 +32,7 @@ EnzymeBFloat16sExt = "BFloat16s"
3332
EnzymeChainRulesCoreExt = "ChainRulesCore"
3433
EnzymeGPUArraysCoreExt = "GPUArraysCore"
3534
EnzymeLogExpFunctionsExt = "LogExpFunctions"
35+
EnzymeSparseArraysExt = "SparseArrays"
3636
EnzymeSpecialFunctionsExt = "SpecialFunctions"
3737
EnzymeStaticArraysExt = "StaticArrays"
3838

ext/EnzymeSparseArraysExt.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
module EnzymeSparseArraysExt
2+
3+
using LinearAlgebra: LinearAlgebra
4+
using SparseArrays: SparseArrays
5+
using Enzyme
6+
using EnzymeCore: EnzymeRules
7+
8+
# Register `SparseArrays.CHOLMOD.Dense{T}` with Enzyme's compiler helpers:
# its element type is reported as `T`, and it is flagged as an
# array-or-vararg type.
@inline function Enzyme.Compiler.ptreltype(
        ::Type{SparseArrays.CHOLMOD.Dense{T}}
    ) where {T}
    return T
end

@inline function Enzyme.Compiler.is_arrayorvararg_ty(
        ::Type{SparseArrays.CHOLMOD.Dense{T}}
    ) where {T}
    return true
end

# Mark the raw SuiteSparse CHOLMOD structs (dense, sparse and factor) so that
# `Enzyme.Compiler.isa_cholmod_struct` answers `true` for them and their subtypes.
function Enzyme.Compiler.isa_cholmod_struct(
        ::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_dense_struct}
    )
    return true
end

function Enzyme.Compiler.isa_cholmod_struct(
        ::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_sparse_struct}
    )
    return true
end

function Enzyme.Compiler.isa_cholmod_struct(
        ::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_factor_struct}
    )
    return true
end
14+
15+
# Augmented forward pass of the reverse-mode rule for the five-argument
# `LinearAlgebra.mul!(C, A, B, α, β)` where `A` is a sparse CSC matrix (or a
# member of `SparseMatrixCSCUnion`) and `B`, `C` are strided vectors/matrices.
#
# It always executes the primal `mul!`, and returns an `AugmentedReturn` holding
#   * the primal result (`C.val`) if the config requests it, else `nothing`;
#   * the shadow (`C.dval`) if the config requests it, else `nothing`;
#   * a cache tuple `(cache_C, cache_A, cache_B, cache_α)` consumed by the
#     matching `EnzymeRules.reverse` method. Entries that are not needed are
#     `nothing`.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfig,
        func::Const{typeof(LinearAlgebra.mul!)},
        ::Type{RT},
        C::Annotation{<:StridedVecOrMat},
        A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
        B::Annotation{<:StridedVecOrMat},
        α::Annotation{<:Number},
        β::Annotation{<:Number}
    ) where {RT}

    # The derivative w.r.t. β needs the value of C from *before* mul! overwrites
    # it, so snapshot it first whenever β is differentiated.
    cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
    # Always need to do forward pass otherwise primal may not be correct
    func.val(C.val, A.val, B.val, α.val, β.val)

    primal = if EnzymeRules.needs_primal(config)
        C.val
    else
        nothing
    end

    shadow = if EnzymeRules.needs_shadow(config)
        C.dval
    else
        nothing
    end

    # `overwritten(config)` holds one flag per argument *including the function
    # itself*, i.e. (func, C, A, B, α, β). A therefore lives at index 3 and B at
    # index 4; indices 5 and 6 would query the scalars α and β instead.
    overwritten = EnzymeRules.overwritten(config)

    # Check if A is overwritten and B is active (and thus required)
    cache_A = (
        overwritten[3]
            && !(typeof(B) <: Const)
            && !(typeof(C) <: Const)
    ) ? copy(A.val) : nothing

    # Likewise B is only needed in the reverse pass when A is active.
    cache_B = (
        overwritten[4]
            && !(typeof(A) <: Const)
            && !(typeof(C) <: Const)
    ) ? copy(B.val) : nothing

    # The derivative w.r.t. α is a dot product of the shadow of C with A * B,
    # so store that product when α is differentiated.
    cache_α = if !isa(α, Const)
        A.val * B.val
    else
        nothing
    end

    cache = (cache_C, cache_A, cache_B, cache_α)

    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
66+
67+
# Reverse pass of the rule for the five-argument
# `LinearAlgebra.mul!(C, A, B, α, β)` with sparse `A` and strided `B`, `C`.
#
# `cache` is the tuple `(cache_C, cache_A, cache_B, cache_α)` produced by the
# matching `EnzymeRules.augmented_primal`; any `nothing` entry falls back to
# the current value of the corresponding argument.
#
# Shadows of `A`, `B` and `C` are updated in place. The return value is a
# 5-tuple with `nothing` for the array arguments and the derivatives for `α`
# and `β` (a scalar for width 1, an `NTuple` for batched mode, or `nothing`
# when the scalar is `Const`).
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfig,
        func::Const{typeof(LinearAlgebra.mul!)},
        ::Type{RT}, cache,
        C::Annotation{<:StridedVecOrMat},
        A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
        B::Annotation{<:StridedVecOrMat},
        α::Annotation{<:Number},
        β::Annotation{<:Number}
    ) where {RT}

    cache_C, cache_A, cache_B, cache_α = cache
    Cval = !isnothing(cache_C) ? cache_C : C.val
    Aval = !isnothing(cache_A) ? cache_A : A.val
    Bval = !isnothing(cache_B) ? cache_B : B.val

    N = EnzymeRules.width(config)
    if !isa(C, Const)
        dCs = C.dval
        dBs = isa(B, Const) ? dCs : B.dval

        # dα = conj(dot(dC, A * B)); computed before dC is rescaled below.
        dα = if !isa(α, Const)
            if N == 1
                Enzyme._project(typeof(α.val), conj(LinearAlgebra.dot(C.dval, cache_α)))
            else
                ntuple(Val(N)) do i
                    Base.@_inline_meta
                    Enzyme._project(
                        typeof(α.val), conj(LinearAlgebra.dot(C.dval[i], cache_α))
                    )
                end
            end
        else
            nothing
        end

        # dβ = conj(dot(dC, C_before_mul)); also uses the unscaled dC.
        dβ = if !isa(β, Const)
            if N == 1
                Enzyme._project(typeof(β.val), conj(LinearAlgebra.dot(C.dval, Cval)))
            else
                ntuple(Val(N)) do i
                    Base.@_inline_meta
                    Enzyme._project(
                        typeof(β.val), conj(LinearAlgebra.dot(C.dval[i], Cval))
                    )
                end
            end
        else
            nothing
        end

        for i in 1:N
            if !isa(A, Const)
                # dA .+= α'dC*B'
                # You need to be careful so that dA sparsity pattern does not change.
                # Otherwise you will get incorrect gradients. So for now we do the slow
                # and bad way of accumulating
                dA = N == 1 ? A.dval : A.dval[i]
                dC = N == 1 ? C.dval : C.dval[i]
                # Now accumulate to preserve the correct sparsity pattern
                I, J, _ = SparseArrays.findnz(dA)
                for k in eachindex(I, J)
                    Ik, Jk = I[k], J[k]
                    # May need to widen if the eltype differ
                    tmp = zero(promote_type(eltype(dA), eltype(dC)))
                    for ti in axes(dC, 2)
                        tmp += dC[Ik, ti] * conj(Bval[Jk, ti])
                    end
                    dA[Ik, Jk] += Enzyme._project(eltype(dA), conj(α.val) * tmp)
                end
                # The loop above replaces the one-line form
                # mul!(dA, dCs, Bval', α.val, true)
            end

            if !isa(B, Const)
                #dB .+= α*A'*dC
                # Get the type of all arguments since we may need to
                # project down to a smaller type during accumulation
                if N == 1
                    Targs = promote_type(eltype(Aval), eltype(dCs), typeof(α.val))
                    Enzyme._muladdproject!(Targs, dBs, Aval', dCs, conj(α.val))
                else
                    # A is not batched, so take its element type directly; indexing
                    # `Aval[i]` with the batch index could go out of bounds.
                    Targs = promote_type(eltype(Aval), eltype(dCs[i]), typeof(α.val))
                    Enzyme._muladdproject!(Targs, dBs[i], Aval', dCs[i], conj(α.val))
                end
            end

            #dC = dC*conj(β.val)
            if N == 1
                dCs .*= Enzyme._project(eltype(dCs), conj(β.val))
            else
                dCs[i] .*= Enzyme._project(eltype(dCs[i]), conj(β.val))
            end
        end
    else
        # C is constant so there is no gradient information to compute

        dα = if !isa(α, Const)
            if N == 1
                zero(α.val)
            else
                ntuple(Returns(zero(α.val)), Val(N))
            end
        else
            nothing
        end

        dβ = if !isa(β, Const)
            if N == 1
                zero(β.val)
            else
                ntuple(Returns(zero(β.val)), Val(N))
            end
        else
            nothing
        end
    end

    return (nothing, nothing, nothing, dα, dβ)
end
180+
181+
end

src/Enzyme.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
114114
export batch_size, onehot, chunkedonehot
115115

116116
using LinearAlgebra
117-
import SparseArrays
118117

119118
import EnzymeCore: EnzymeRules
120119
export EnzymeRules

src/absint.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
326326
return larg, offset
327327
end
328328

329+
# Generic fallback: a type is not considered a CHOLMOD struct unless a more
# specific method says otherwise (the SparseArrays extension adds such methods
# for the cholmod dense/sparse/factor structs).
function isa_cholmod_struct(typ)
    return false
end
330+
329331
function abs_typeof(
330332
@nospecialize(arg::LLVM.Value),
331333
partial::Bool = false, seenphis = Set{LLVM.PHIInst}()
@@ -600,7 +602,7 @@ function abs_typeof(
600602
# add the extra pointer offset when loading here]. However for pointers constructed by ccall outside julia
601603
# to a julia object, which are not inline by type but appear so, like SparseArrays, this is a problem
602604
# and merits further investigation. x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2085
603-
if !Base.allocatedinline(typ) && typ != SparseArrays.cholmod_dense_struct && typ != SparseArrays.cholmod_sparse_struct && typ != SparseArrays.cholmod_factor_struct
605+
if !Base.allocatedinline(typ) && !isa_cholmod_struct(typ)
604606
shouldLoad = false
605607
offset %= sizeof(Int)
606608
else

src/analyses/activity.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ end
8181
@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T
8282
@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V
8383
@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V
84-
@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T
8584
@static if VERSION < v"1.11-"
8685
else
8786
@inline ptreltype(::Type{Memory{T}}) where T = T
@@ -95,7 +94,6 @@ end
9594
@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true
9695
@inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true
9796
@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true
98-
@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true
9997
@static if VERSION < v"1.11-"
10098
else
10199
@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true

src/compiler.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ import Enzyme_jll
3838
import GPUCompiler: CompilerJob, compile, safe_name
3939
using LLVM.Interop
4040
import LLVM: Target, TargetMachine
41-
import SparseArrays
4241
using Printf
4342

4443
using Preferences

0 commit comments

Comments
 (0)