Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.13.75"


[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Expand All @@ -18,13 +17,13 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand All @@ -33,6 +32,7 @@ EnzymeBFloat16sExt = "BFloat16s"
EnzymeChainRulesCoreExt = "ChainRulesCore"
EnzymeGPUArraysCoreExt = "GPUArraysCore"
EnzymeLogExpFunctionsExt = "LogExpFunctions"
EnzymeSparseArraysExt = "SparseArrays"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
EnzymeStaticArraysExt = "StaticArrays"

Expand Down
181 changes: 181 additions & 0 deletions ext/EnzymeSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
module EnzymeSparseArraysExt

using LinearAlgebra: LinearAlgebra
using SparseArrays: SparseArrays
using Enzyme
using EnzymeCore: EnzymeRules

@inline Enzyme.Compiler.ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where {T} = T
@inline Enzyme.Compiler.is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where {T} = true

Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_dense_struct}) = true
Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_sparse_struct}) = true
Enzyme.Compiler.isa_cholmod_struct(::Core.Type{<:SparseArrays.LibSuiteSparse.cholmod_factor_struct}) = true

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT},
C::Annotation{<:StridedVecOrMat},
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
# Always need to do forward pass otherwise primal may not be correct
func.val(C.val, A.val, B.val, α.val, β.val)

primal = if EnzymeRules.needs_primal(config)
C.val
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
C.dval
else
nothing
end


# Check if A is overwritten and B is active (and thus required)
cache_A = (
EnzymeRules.overwritten(config)[5]
&& !(typeof(B) <: Const)
&& !(typeof(C) <: Const)
) ? copy(A.val) : nothing

cache_B = (
EnzymeRules.overwritten(config)[6]
&& !(typeof(A) <: Const)
&& !(typeof(C) <: Const)
) ? copy(B.val) : nothing

if !isa(α, Const)
cache_α = A.val * B.val
else
cache_α = nothing
end

cache = (cache_C, cache_A, cache_B, cache_α)

return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT}, cache,
C::Annotation{<:StridedVecOrMat},
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C, cache_A, cache_B, cache_α = cache
Cval = !isnothing(cache_C) ? cache_C : C.val
Aval = !isnothing(cache_A) ? cache_A : A.val
Bval = !isnothing(cache_B) ? cache_B : B.val

N = EnzymeRules.width(config)
if !isa(C, Const)
dCs = C.dval
dBs = isa(B, Const) ? dCs : B.dval
dα = if !isa(α, Const)
if N == 1
Enzyme._project(typeof(α.val), conj(LinearAlgebra.dot(C.dval, cache_α)))
else
ntuple(Val(N)) do i
Base.@_inline_meta
Enzyme._project(typeof(α.val), conj(LinearAlgebra.dot(C.dval[i], cache_α)))
end
end
else
nothing
end

dβ = if !isa(β, Const)
if N == 1
Enzyme._project(typeof(β.val), conj(LinearAlgebra.dot(C.dval, Cval)))
else
ntuple(Val(N)) do i
Base.@_inline_meta
Enzyme._project(typeof(β.val), conj(LinearAlgebra.dot(C.dval[i], Cval)))
end
end
else
nothing
end

for i in 1:N
if !isa(A, Const)
# dA .+= α'dC*B'
# You need to be careful so that dA sparsity pattern does not change. Otherwise
# you will get incorrect gradients. So for now we do the slow and bad way of accumulating
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i]
dC = EnzymeRules.width(config) == 1 ? C.dval : C.dval[i]
# Now accumulate to preserve the correct sparsity pattern
I, J, _ = SparseArrays.findnz(dA)
for k in eachindex(I, J)
Ik, Jk = I[k], J[k]
# May need to widen if the eltype differ
tmp = zero(promote_type(eltype(dA), eltype(dC)))
for ti in axes(dC, 2)
tmp += dC[Ik, ti] * conj(Bval[Jk, ti])
end
dA[Ik, Jk] += Enzyme._project(eltype(dA), conj(α.val) * tmp)
end
# mul!(dA, dCs, Bval', α.val, true)
end

if !isa(B, Const)
#dB .+= α*A'*dC
# Get the type of all arguments since we may need to
# project down to a smaller type during accumulation
if N == 1
Targs = promote_type(eltype(Aval), eltype(dCs), typeof(α.val))
Enzyme._muladdproject!(Targs, dBs, Aval', dCs, conj(α.val))
else
Targs = promote_type(eltype(Aval[i]), eltype(dCs[i]), typeof(α.val))
Enzyme._muladdproject!(Targs, dBs[i], Aval', dCs[i], conj(α.val))
end
end
#dC = dC*conj(β.val)
if N == 1
dCs .*= Enzyme._project(eltype(dCs), conj(β.val))
else
dCs[i] .*= Enzyme._project(eltype(dCs[i]), conj(β.val))
end
end
else
# C is constant so there is no gradient information to compute

dα = if !isa(α, Const)
if N == 1
zero(α.val)
else
ntuple(Returns(zero(α.val)), Val(N))
end
else
nothing
end


dβ = if !isa(β, Const)
if N == 1
zero(β.val)
else
ntuple(Returns(zero(β.val)), Val(N))
end
else
nothing
end
end

return (nothing, nothing, nothing, dα, dβ)
end

end
1 change: 0 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot

using LinearAlgebra
import SparseArrays

import EnzymeCore: EnzymeRules
export EnzymeRules
Expand Down
4 changes: 3 additions & 1 deletion src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
return larg, offset
end

isa_cholmod_struct(typ) = false

function abs_typeof(
@nospecialize(arg::LLVM.Value),
partial::Bool = false, seenphis = Set{LLVM.PHIInst}()
Expand Down Expand Up @@ -600,7 +602,7 @@ function abs_typeof(
# add the extra poitner offset when loading here]. However for pointers constructed by ccall outside julia
# to a julia object, which are not inline by type but appear so, like SparseArrays, this is a problem
# and merits further investigation. x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2085
if !Base.allocatedinline(typ) && typ != SparseArrays.cholmod_dense_struct && typ != SparseArrays.cholmod_sparse_struct && typ != SparseArrays.cholmod_factor_struct
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately presently this cannot be moved into an extension (it will not get picked up from generated function shenanigans). the longer term soltn i need to figure out is how to handle this specific edge case of a ccall return, but I haven't gotten to so yet [and it only came up for sparsearrays]

Copy link
Member

@vchuravy vchuravy Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would need to perform an invoke in world (probably invole_in_world_total) with the world being the job.world

if !Base.allocatedinline(typ) && !isa_cholmod_struct(typ)
shouldLoad = false
offset %= sizeof(Int)
else
Expand Down
2 changes: 0 additions & 2 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ end
@inline ptreltype(::Type{Tuple{Vararg{T}}}) where {T} = T
@inline ptreltype(::Type{IdDict{K,V}}) where {K,V} = V
@inline ptreltype(::Type{IdDict{K,V} where K}) where {V} = V
@inline ptreltype(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = T
@static if VERSION < v"1.11-"
else
@inline ptreltype(::Type{Memory{T}}) where T = T
Expand All @@ -95,7 +94,6 @@ end
@inline is_arrayorvararg_ty(::Type{Base.RefValue{T}}) where {T} = true
@inline is_arrayorvararg_ty(::Type{IdDict{K,V}}) where {K,V} = true
@inline is_arrayorvararg_ty(::Type{IdDict{K,V} where K}) where {V} = true
@inline is_arrayorvararg_ty(::Type{SparseArrays.CHOLMOD.Dense{T}}) where T = true
@static if VERSION < v"1.11-"
else
@inline is_arrayorvararg_ty(::Type{Memory{T}}) where T = true
Expand Down
1 change: 0 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import Enzyme_jll
import GPUCompiler: CompilerJob, compile, safe_name
using LLVM.Interop
import LLVM: Target, TargetMachine
import SparseArrays
using Printf

using Preferences
Expand Down
Loading
Loading