Skip to content

Commit 272eeb5

Browse files
authored
fix: make (multi)basis work on CuArray (#810)
* fix: make (multi)basis work on CuArray * Add CUDA tests * Changelog * Test drafts on GPU * Install DifferentiationInterface * Run main * Fix * Dev right version of DI * Fix * Cov
1 parent d8905f5 commit 272eeb5

File tree

8 files changed

+80
-48
lines changed

8 files changed

+80
-48
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
steps:
22
- label: "DI GPU tests"
3-
if: |
4-
!build.pull_request.draft &&
5-
build.pull_request.labels includes "gpu"
3+
if: build.pull_request.labels includes "gpu"
64
plugins:
75
- JuliaCI/julia#v1:
86
version: "1"
97
command: |
10-
julia ./DifferentiationInterface/test/GPU/CUDA/simple.jl
8+
julia ./DifferentiationInterface/test/GPU/CUDA/main.jl
119
agents:
1210
queue: "juliagpu"
1311
cuda: "*"

DifferentiationInterface/CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.7.1]
11+
12+
### Fixed
13+
14+
- Make basis work for `CuArray` ([#810])
15+
1016
## [0.7.0]
1117

1218
### Changed
@@ -27,11 +33,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2733

2834
- Allocate Enzyme shadow memory during preparation ([#782])
2935

30-
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.0...main
36+
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.1...main
37+
[0.7.1]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.0...DifferentiationInterface-v0.7.1
3138
[0.7.0]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.54...DifferentiationInterface-v0.7.0
3239
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
3340
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
3441

42+
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
3543
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
3644
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
3745
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790

DifferentiationInterface/Project.toml

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -41,7 +41,9 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4141
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4242
DifferentiationInterfaceGTPSAExt = "GTPSA"
4343
DifferentiationInterfaceMooncakeExt = "Mooncake"
44-
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
44+
DifferentiationInterfacePolyesterForwardDiffExt = [
45+
"PolyesterForwardDiff", "ForwardDiff", "DiffResults"
46+
]
4547
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4648
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4749
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -121,4 +123,21 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
121123
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
122124

123125
[targets]
124-
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
126+
test = [
127+
"ADTypes",
128+
"Aqua",
129+
"ComponentArrays",
130+
"DataFrames",
131+
"ExplicitImports",
132+
"JET",
133+
"JLArrays",
134+
"JuliaFormatter",
135+
"Pkg",
136+
"Random",
137+
"SparseArrays",
138+
"SparseConnectivityTracer",
139+
"SparseMatrixColorings",
140+
"StableRNGs",
141+
"StaticArrays",
142+
"Test",
143+
]

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,21 @@
11
module DifferentiationInterfaceGPUArraysCoreExt
22

33
import DifferentiationInterface as DI
4-
using GPUArraysCore: AbstractGPUArray
4+
using GPUArraysCore: @allowscalar, AbstractGPUArray
55

6-
"""
7-
OneElement
8-
9-
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
10-
"""
11-
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
12-
ind::I
13-
val::T
14-
a::A
15-
16-
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
17-
right_ind = eachindex(a)[ind]
18-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
19-
end
20-
21-
function OneElement(
22-
ind::CartesianIndex{N}, val::T, a::A
23-
) where {N,T,A<:AbstractArray{T,N}}
24-
linear_ind = LinearIndices(a)[ind]
25-
right_ind = eachindex(a)[linear_ind]
26-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
27-
end
28-
end
29-
30-
Base.size(oe::OneElement) = size(oe.a)
31-
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
32-
33-
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
34-
return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a)))
6+
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
7+
b = similar(a)
8+
fill!(b, zero(T))
9+
@allowscalar b[i] = one(T)
10+
return b
3511
end
3612

37-
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
38-
b = zero(a)
39-
b .+= OneElement(i, one(T), a)
13+
function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
14+
b = similar(a)
15+
fill!(b, zero(T))
16+
for i in inds
17+
@allowscalar b[i] = one(T)
18+
end
4019
return b
4120
end
4221

DifferentiationInterface/src/utils/basis.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ end
1818
multibasis(a::AbstractArray, inds)
1919
2020
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
21-
22-
!!! warning
23-
Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
2421
"""
2522
function multibasis(a::AbstractArray{T}, inds) where {T}
2623
b = similar(a)

DifferentiationInterface/test/Core/Internals/basis.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DifferentiationInterface: basis
1+
using DifferentiationInterface: basis, multibasis
22
using LinearAlgebra
33
using StaticArrays, JLArrays
44
using Test
@@ -8,6 +8,9 @@ using Test
88
@test basis(rand(3), 2) isa Vector
99
@test basis(rand(3), 2) == b_ref
1010
@test basis(jl(rand(3)), 2) isa JLArray
11+
@test Array(basis(jl(rand(3)), 2)) == [0, 1, 0]
12+
@test multibasis(jl(rand(3)), [1, 2]) isa JLArray
13+
@test Array(multibasis(jl(rand(3)), [1, 2])) == [1, 1, 0]
1114
@test all(basis(jl(rand(3)), 2) .== b_ref)
1215
@test basis(@SVector(rand(3)), 2) isa SVector
1316
@test basis(@SVector(rand(3)), 2) == b_ref
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@info "Testing on CUDA"
2+
using Pkg
3+
Pkg.add("CUDA")
4+
Pkg.develop(PackageSpec(; path="./DifferentiationInterface"))
5+
using Test
6+
7+
@testset verbose = true "Simple" begin
8+
include("simple.jl")
9+
end
Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
1-
@info "Testing on CUDA"
2-
using Pkg
3-
Pkg.add("CUDA")
41
using CUDA
2+
using DifferentiationInterface
3+
import DifferentiationInterface as DI
4+
using LinearAlgebra
5+
using Test
6+
57
CUDA.versioninfo()
8+
9+
@testset "Basis" begin
10+
x = CuVector(rand(Float32, 3))
11+
b = DI.basis(x, 2)
12+
@test Array(b) == [0, 1, 0]
13+
14+
X = CuMatrix(rand(Float32, 2, 2))
15+
B = DI.multibasis(X, [2, 3])
16+
@test Array(B) == [0 1; 1 0]
17+
end
18+
19+
@testset "Jacobian" begin
20+
x = CuVector(rand(Float32, 3))
21+
backend = DI.AutoSimpleFiniteDiff()
22+
J = jacobian(identity, backend, x)
23+
@test (J .!= 0) == I
24+
end

0 commit comments

Comments
 (0)