From 3d9ca6ce964b94795d2b9ed94da3c2fca7c1cb7e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Sep 2025 22:25:38 -0400 Subject: [PATCH] refactor: move SparseArrays into an extension --- Project.toml | 4 +- ext/EnzymeSparseArraysExt.jl | 181 +++++++++++++++++++++++++++++++++++ src/Enzyme.jl | 1 - src/absint.jl | 4 +- src/analyses/activity.jl | 2 - src/compiler.jl | 1 - src/internal_rules.jl | 166 -------------------------------- 7 files changed, 186 insertions(+), 173 deletions(-) create mode 100644 ext/EnzymeSparseArraysExt.jl diff --git a/Project.toml b/Project.toml index ae55bc0331..45eede857d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] version = "0.13.75" - [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" @@ -18,13 +17,13 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -33,6 +32,7 @@ EnzymeBFloat16sExt = "BFloat16s" EnzymeChainRulesCoreExt = "ChainRulesCore" EnzymeGPUArraysCoreExt = "GPUArraysCore" EnzymeLogExpFunctionsExt = "LogExpFunctions" +EnzymeSparseArraysExt = "SparseArrays" EnzymeSpecialFunctionsExt = "SpecialFunctions" EnzymeStaticArraysExt = "StaticArrays" diff --git a/ext/EnzymeSparseArraysExt.jl b/ext/EnzymeSparseArraysExt.jl new file mode 100644 index 0000000000..a6e85eeecd 
--- /dev/null
+++ b/ext/EnzymeSparseArraysExt.jl
@@ -0,0 +1,187 @@
+module EnzymeSparseArraysExt
+
+using LinearAlgebra: LinearAlgebra
+using SparseArrays: SparseArrays
+using Enzyme
+using EnzymeCore: EnzymeRules
+
+# Teach Enzyme's compiler queries about CHOLMOD's dense wrapper type.
+@inline Enzyme.Compiler.ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where {T} = T
+@inline Enzyme.Compiler.is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where {T} = true
+
+# Flag the raw CHOLMOD structs; every other type hits the `false` fallback that
+# `src/absint.jl` defines.
+Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_dense_struct}) = true
+Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_sparse_struct}) = true
+Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_factor_struct}) = true
+
+# Augmented forward pass of the reverse-mode rule for `mul!(C, A, B, α, β)` with a
+# sparse `A`.
+#
+# Runs the primal `mul!` in place and returns an `AugmentedReturn` whose tape is
+# `(cache_C, cache_A, cache_B, cache_α)`; an entry is `nothing` when the reverse
+# pass does not need it.
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfig,
+        func::Const{typeof(LinearAlgebra.mul!)},
+        ::Type{RT},
+        C::Annotation{<:StridedVecOrMat},
+        A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
+        B::Annotation{<:StridedVecOrMat},
+        α::Annotation{<:Number},
+        β::Annotation{<:Number}
+    ) where {RT}
+
+    # dβ is computed from C as it was before the update, so snapshot it first
+    cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
+    # Always need to do forward pass otherwise primal may not be correct
+    func.val(C.val, A.val, B.val, α.val, β.val)
+
+    primal = EnzymeRules.needs_primal(config) ? C.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
+
+    # `overwritten(config)` has one flag per argument, counting the function itself:
+    # (mul!, C, A, B, α, β), so A is entry 3 and B is entry 4.
+    # A is read when accumulating dB: tape a copy if it may be overwritten.
+    cache_A = (
+        EnzymeRules.overwritten(config)[3]
+            && !(typeof(B) <: Const)
+            && !(typeof(C) <: Const)
+    ) ? copy(A.val) : nothing
+
+    # B is read when accumulating dA: tape a copy if it may be overwritten.
+    cache_B = (
+        EnzymeRules.overwritten(config)[4]
+            && !(typeof(A) <: Const)
+            && !(typeof(C) <: Const)
+    ) ? copy(B.val) : nothing
+
+    # The reverse pass takes the inner product of dC with A * B to obtain dα
+    cache_α = !isa(α, Const) ? A.val * B.val : nothing
+
+    cache = (cache_C, cache_A, cache_B, cache_α)
+
+    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+# Reverse pass of the rule for `mul!(C, A, B, α, β)` with a sparse `A`.
+#
+# Accumulates into the shadows of `A` and `B` (for `A`, on its stored entries only),
+# scales the shadow of `C` by `conj(β)`, and returns
+# `(nothing, nothing, nothing, dα, dβ)`. `dα`/`dβ` are `nothing` for `Const` scalars
+# and `N`-tuples when the batch width `N` is greater than one.
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfig,
+        func::Const{typeof(LinearAlgebra.mul!)},
+        ::Type{RT}, cache,
+        C::Annotation{<:StridedVecOrMat},
+        A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
+        B::Annotation{<:StridedVecOrMat},
+        α::Annotation{<:Number},
+        β::Annotation{<:Number}
+    ) where {RT}
+
+    cache_C, cache_A, cache_B, cache_α = cache
+    # Use the taped copies when present, otherwise the live values
+    Cval = !isnothing(cache_C) ? cache_C : C.val
+    Aval = !isnothing(cache_A) ? cache_A : A.val
+    Bval = !isnothing(cache_B) ? cache_B : B.val
+
+    N = EnzymeRules.width(config)
+    if !isa(C, Const)
+        dCs = C.dval
+        dBs = isa(B, Const) ? dCs : B.dval
+        dα = if !isa(α, Const)
+            if N == 1
+                Enzyme._project(typeof(α.val), conj(LinearAlgebra.dot(C.dval, cache_α)))
+            else
+                ntuple(Val(N)) do i
+                    Base.@_inline_meta
+                    Enzyme._project(typeof(α.val), conj(LinearAlgebra.dot(C.dval[i], cache_α)))
+                end
+            end
+        else
+            nothing
+        end
+
+        dβ = if !isa(β, Const)
+            if N == 1
+                Enzyme._project(typeof(β.val), conj(LinearAlgebra.dot(C.dval, Cval)))
+            else
+                ntuple(Val(N)) do i
+                    Base.@_inline_meta
+                    Enzyme._project(typeof(β.val), conj(LinearAlgebra.dot(C.dval[i], Cval)))
+                end
+            end
+        else
+            nothing
+        end
+
+        for i in 1:N
+            if !isa(A, Const)
+                # dA .+= α'dC*B'
+                # You need to be careful so that dA sparsity pattern does not change. Otherwise
+                # you will get incorrect gradients. So for now we do the slow and bad way of accumulating
+                dA = N == 1 ? A.dval : A.dval[i]
+                dC = N == 1 ? C.dval : C.dval[i]
+                # Now accumulate to preserve the correct sparsity pattern
+                I, J, _ = SparseArrays.findnz(dA)
+                for k in eachindex(I, J)
+                    Ik, Jk = I[k], J[k]
+                    # May need to widen if the eltypes differ
+                    tmp = zero(promote_type(eltype(dA), eltype(dC)))
+                    for ti in axes(dC, 2)
+                        tmp += dC[Ik, ti] * conj(Bval[Jk, ti])
+                    end
+                    dA[Ik, Jk] += Enzyme._project(eltype(dA), conj(α.val) * tmp)
+                end
+            end
+
+            if !isa(B, Const)
+                # Accumulate the A' * dC contribution into dB, scaled by conj(α).
+                # Promote over all arguments since we may need to
+                # project down to a smaller type during accumulation.
+                # `Aval` is a single matrix shared by every batch entry, so take its
+                # eltype directly instead of indexing it with the batch index.
+                if N == 1
+                    Targs = promote_type(eltype(Aval), eltype(dCs), typeof(α.val))
+                    Enzyme._muladdproject!(Targs, dBs, Aval', dCs, conj(α.val))
+                else
+                    Targs = promote_type(eltype(Aval), eltype(dCs[i]), typeof(α.val))
+                    Enzyme._muladdproject!(Targs, dBs[i], Aval', dCs[i], conj(α.val))
+                end
+            end
+            # dC = dC*conj(β.val)
+            if N == 1
+                dCs .*= Enzyme._project(eltype(dCs), conj(β.val))
+            else
+                dCs[i] .*= Enzyme._project(eltype(dCs[i]), conj(β.val))
+            end
+        end
+    else
+        # C is constant so there is no gradient information to compute
+        dα = if !isa(α, Const)
+            if N == 1
+                zero(α.val)
+            else
+                ntuple(Returns(zero(α.val)), Val(N))
+            end
+        else
+            nothing
+        end
+
+        dβ = if !isa(β, Const)
+            if N == 1
+                zero(β.val)
+            else
+                ntuple(Returns(zero(β.val)), Val(N))
+            end
+        else
+            nothing
+        end
+    end
+
+    return (nothing, nothing, nothing, dα, dβ)
+end
+
+end
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 7061339431..98f2cca9bc 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -114,7 +114,6 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot using LinearAlgebra -import SparseArrays import EnzymeCore: EnzymeRules export EnzymeRules diff --git a/src/absint.jl b/src/absint.jl index c98e6fabcc..16640021b2 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -326,6 +326,8 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo return larg, offset end +isa_cholmod_struct(typ) = false + function abs_typeof( @nospecialize(arg::LLVM.Value), partial::Bool = false, seenphis = Set{LLVM.PHIInst}() @@ -600,7 +602,7 @@ function abs_typeof( # add the extra poitner offset when loading here]. However for pointers constructed by ccall outside julia # to a julia object, which are not inline by type but appear so, like SparseArrays, this is a problem # and merits further investigation. x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2085 - if !Base.allocatedinline(typ) && typ != SparseArrays.cholmod_dense_struct && typ != SparseArrays.cholmod_sparse_struct && typ != SparseArrays.cholmod_factor_struct + if !Base.allocatedinline(typ) && !isa_cholmod_struct(typ) shouldLoad = false offset %= sizeof(Int) else diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 477583757c..fac4476dbc 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -81,7 +81,6 @@ end @inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T @inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V @inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V -@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T @static if VERSION < v"1.11-" else @inline ptreltype(::Type{Memory{T}}) where T = T @@ -95,7 +94,6 @@ end @inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true @inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true -@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true @static if VERSION < v"1.11-" else 
@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true diff --git a/src/compiler.jl b/src/compiler.jl index 0cf75eb06f..4124c14ece 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,6 @@ import Enzyme_jll import GPUCompiler: CompilerJob, compile, safe_name using LLVM.Interop import LLVM: Target, TargetMachine -import SparseArrays using Printf using Preferences diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 1da1f0f186..1fbcfa654c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -741,58 +741,6 @@ function EnzymeRules.reverse( end -function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfig, - func::Const{typeof(LinearAlgebra.mul!)}, - ::Type{RT}, - C::Annotation{<:StridedVecOrMat}, - A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, - B::Annotation{<:StridedVecOrMat}, - α::Annotation{<:Number}, - β::Annotation{<:Number} - ) where {RT} - - cache_C = !(isa(β, Const)) ? copy(C.val) : nothing - # Always need to do forward pass otherwise primal may not be correct - func.val(C.val, A.val, B.val, α.val, β.val) - - primal = if EnzymeRules.needs_primal(config) - C.val - else - nothing - end - - shadow = if EnzymeRules.needs_shadow(config) - C.dval - else - nothing - end - - - # Check if A is overwritten and B is active (and thus required) - cache_A = ( - EnzymeRules.overwritten(config)[5] - && !(typeof(B) <: Const) - && !(typeof(C) <: Const) - ) ? copy(A.val) : nothing - - cache_B = ( - EnzymeRules.overwritten(config)[6] - && !(typeof(A) <: Const) - && !(typeof(C) <: Const) - ) ? 
copy(B.val) : nothing - - if !isa(α, Const) - cache_α = A.val * B.val - else - cache_α = nothing - end - - cache = (cache_C, cache_A, cache_B, cache_α) - - return EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - # This is required to handle arugments that mix real and complex numbers _project(::Type{<:Real}, x) = x _project(::Type{<:Real}, x::Complex) = real(x) @@ -808,120 +756,6 @@ function _muladdproject!(::Type{<:Complex}, dB::AbstractArray{<:Real}, A::Abstra end -function EnzymeRules.reverse( - config::EnzymeRules.RevConfig, - func::Const{typeof(LinearAlgebra.mul!)}, - ::Type{RT}, cache, - C::Annotation{<:StridedVecOrMat}, - A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, - B::Annotation{<:StridedVecOrMat}, - α::Annotation{<:Number}, - β::Annotation{<:Number} - ) where {RT} - - cache_C, cache_A, cache_B, cache_α = cache - Cval = !isnothing(cache_C) ? cache_C : C.val - Aval = !isnothing(cache_A) ? cache_A : A.val - Bval = !isnothing(cache_B) ? cache_B : B.val - - N = EnzymeRules.width(config) - if !isa(C, Const) - dCs = C.dval - dBs = isa(B, Const) ? dCs : B.dval - dα = if !isa(α, Const) - if N == 1 - _project(typeof(α.val), conj(LinearAlgebra.dot(C.dval, cache_α))) - else - ntuple(Val(N)) do i - Base.@_inline_meta - _project(typeof(α.val), conj(LinearAlgebra.dot(C.dval[i], cache_α))) - end - end - else - nothing - end - - dβ = if !isa(β, Const) - if N == 1 - _project(typeof(β.val), conj(LinearAlgebra.dot(C.dval, Cval))) - else - ntuple(Val(N)) do i - Base.@_inline_meta - _project(typeof(β.val), conj(LinearAlgebra.dot(C.dval[i], Cval))) - end - end - else - nothing - end - - for i in 1:N - if !isa(A, Const) - # dA .+= α'dC*B' - # You need to be careful so that dA sparsity pattern does not change. Otherwise - # you will get incorrect gradients. So for now we do the slow and bad way of accumulating - dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i] - dC = EnzymeRules.width(config) == 1 ? 
C.dval : C.dval[i] - # Now accumulate to preserve the correct sparsity pattern - I, J, _ = SparseArrays.findnz(dA) - for k in eachindex(I, J) - Ik, Jk = I[k], J[k] - # May need to widen if the eltype differ - tmp = zero(promote_type(eltype(dA), eltype(dC))) - for ti in axes(dC, 2) - tmp += dC[Ik, ti] * conj(Bval[Jk, ti]) - end - dA[Ik, Jk] += _project(eltype(dA), conj(α.val) * tmp) - end - # mul!(dA, dCs, Bval', α.val, true) - end - - if !isa(B, Const) - #dB .+= α*A'*dC - # Get the type of all arguments since we may need to - # project down to a smaller type during accumulation - if N == 1 - Targs = promote_type(eltype(Aval), eltype(dCs), typeof(α.val)) - _muladdproject!(Targs, dBs, Aval', dCs, conj(α.val)) - else - Targs = promote_type(eltype(Aval[i]), eltype(dCs[i]), typeof(α.val)) - _muladdproject!(Targs, dBs[i], Aval', dCs[i], conj(α.val)) - end - end - #dC = dC*conj(β.val) - if N == 1 - dCs .*= _project(eltype(dCs), conj(β.val)) - else - dCs[i] .*= _project(eltype(dCs[i]), conj(β.val)) - end - end - else - # C is constant so there is no gradient information to compute - - dα = if !isa(α, Const) - if N == 1 - zero(α.val) - else - ntuple(Returns(zero(α.val)), Val(N)) - end - else - nothing - end - - - dβ = if !isa(β, Const) - if N == 1 - zero(β.val) - else - ntuple(Returns(zero(β.val)), Val(N)) - end - else - nothing - end - end - - return (nothing, nothing, nothing, dα, dβ) -end -