Skip to content

Commit c3ff78e

Browse files
authored
perf: optimize multibasis for sparse differentiation (#777)
* perf: optimize multibasis for sparse differentiation * Add GPUArrays extension * Undo formatting noise * More noise * More noise
1 parent f442ed4 commit c3ff78e

File tree

4 files changed

+59
-64
lines changed

4 files changed

+59
-64
lines changed

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.51"
4+
version = "0.6.52"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -17,6 +17,7 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
1717
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1818
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1919
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
20+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2021
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
2122
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2223
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
@@ -37,6 +38,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3738
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3839
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3940
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
41+
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4042
DifferentiationInterfaceGTPSAExt = "GTPSA"
4143
DifferentiationInterfaceMooncakeExt = "Mooncake"
4244
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
@@ -109,21 +111,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
109111
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
110112

111113
[targets]
112-
test = [
113-
"ADTypes",
114-
"Aqua",
115-
"ComponentArrays",
116-
"DataFrames",
117-
"ExplicitImports",
118-
"JET",
119-
"JLArrays",
120-
"JuliaFormatter",
121-
"Pkg",
122-
"Random",
123-
"SparseArrays",
124-
"SparseConnectivityTracer",
125-
"SparseMatrixColorings",
126-
"StableRNGs",
127-
"StaticArrays",
128-
"Test",
129-
]
114+
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
module DifferentiationInterfaceGPUArraysCoreExt
2+
3+
import DifferentiationInterface as DI
4+
using GPUArraysCore: AbstractGPUArray
5+
6+
"""
7+
OneElement
8+
9+
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
10+
"""
11+
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
12+
ind::I
13+
val::T
14+
a::A
15+
16+
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
17+
right_ind = eachindex(a)[ind]
18+
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
19+
end
20+
21+
function OneElement(
22+
ind::CartesianIndex{N}, val::T, a::A
23+
) where {N,T,A<:AbstractArray{T,N}}
24+
linear_ind = LinearIndices(a)[ind]
25+
right_ind = eachindex(a)[linear_ind]
26+
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
27+
end
28+
end
29+
30+
Base.size(oe::OneElement) = size(oe.a)
31+
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
32+
33+
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
34+
return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a)))
35+
end
36+
37+
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
38+
b = zero(a)
39+
b .+= OneElement(i, one(T), a)
40+
return b
41+
end
42+
43+
end
Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,3 @@
1-
"""
2-
OneElement
3-
4-
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
5-
"""
6-
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
7-
ind::I
8-
val::T
9-
a::A
10-
11-
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
12-
right_ind = eachindex(a)[ind]
13-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
14-
end
15-
16-
function OneElement(
17-
ind::CartesianIndex{N}, val::T, a::A
18-
) where {N,T,A<:AbstractArray{T,N}}
19-
linear_ind = LinearIndices(a)[ind]
20-
right_ind = eachindex(a)[linear_ind]
21-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
22-
end
23-
end
24-
25-
Base.size(oe::OneElement) = size(oe.a)
26-
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
27-
28-
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
29-
if ind == oe.ind
30-
return oe.val
31-
else
32-
return zero(eltype(oe.a))
33-
end
34-
end
35-
36-
function Base.getindex(oe::OneElement{<:CartesianIndex{N}}, ind::Vararg{Int,N}) where {N}
37-
if ind == Tuple(oe.ind)
38-
return oe.val
39-
else
40-
return zero(eltype(oe.a))
41-
end
42-
end
43-
441
"""
452
basis(a::AbstractArray, i)
463
@@ -49,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`.
496
function basis(a::AbstractArray{T}, i) where {T}
507
b = similar(a)
518
fill!(b, zero(T))
52-
b .+= OneElement(i, one(T), a)
9+
b[i] = one(T)
5310
if ismutable_array(a)
5411
return b
5512
else
@@ -61,12 +18,15 @@ end
6118
multibasis(a::AbstractArray, inds)
6219
6320
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
21+
22+
!!! warning
23+
Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
6424
"""
6525
function multibasis(a::AbstractArray{T}, inds) where {T}
6626
b = similar(a)
6727
fill!(b, zero(T))
6828
for i in inds
69-
b .+= OneElement(i, one(T), a)
29+
b[i] = one(T)
7030
end
7131
return ismutable_array(a) ? b : map(+, zero(a), b)
7232
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DifferentiationInterface:
55
AutoReverseFromPrimitive,
66
DenseSparsityDetector
77
using SparseMatrixColorings
8+
using JLArrays, StaticArrays
89
using Test
910

1011
LOGGING = get(ENV, "CI", "false") == "false"
@@ -137,3 +138,9 @@ end
137138
pushforward, copyto!, [1.0], AutoSimpleFiniteDiff(), [1.0], ([1.0], [1.0])
138139
)
139140
end
141+
142+
@testset "Weird arrays" begin
143+
test_differentiation(
144+
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
145+
)
146+
end;

0 commit comments

Comments
 (0)