|
1 | 1 | module DifferentiationInterfaceGPUArraysCoreExt |
2 | 2 |
|
3 | 3 | import DifferentiationInterface as DI |
4 | | -using GPUArraysCore: AbstractGPUArray |
| 4 | +using GPUArraysCore: @allowscalar, AbstractGPUArray |
5 | 5 |
|
6 | | -""" |
7 | | - OneElement |
8 | | -
|
9 | | -Efficient storage for a one-hot array, aka an array in the standard Euclidean basis. |
10 | | -""" |
11 | | -struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N} |
12 | | - ind::I |
13 | | - val::T |
14 | | - a::A |
15 | | - |
16 | | - function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}} |
17 | | - right_ind = eachindex(a)[ind] |
18 | | - return new{typeof(right_ind),N,T,A}(right_ind, val, a) |
19 | | - end |
20 | | - |
21 | | - function OneElement( |
22 | | - ind::CartesianIndex{N}, val::T, a::A |
23 | | - ) where {N,T,A<:AbstractArray{T,N}} |
24 | | - linear_ind = LinearIndices(a)[ind] |
25 | | - right_ind = eachindex(a)[linear_ind] |
26 | | - return new{typeof(right_ind),N,T,A}(right_ind, val, a) |
27 | | - end |
28 | | -end |
29 | | - |
30 | | -Base.size(oe::OneElement) = size(oe.a) |
31 | | -Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a) |
32 | | - |
33 | | -function Base.getindex(oe::OneElement{<:Integer}, ind::Integer) |
34 | | - return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a))) |
| 6 | +function DI.basis(a::AbstractGPUArray{T}, i) where {T} |
| 7 | + b = similar(a) |
| 8 | + fill!(b, zero(T)) |
| 9 | + @allowscalar b[i] = one(T) |
| 10 | + return b |
35 | 11 | end |
36 | 12 |
|
37 | | -function DI.basis(a::AbstractGPUArray{T}, i) where {T} |
38 | | - b = zero(a) |
39 | | - b .+= OneElement(i, one(T), a) |
| 13 | +function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T} |
| 14 | + b = similar(a) |
| 15 | + fill!(b, zero(T)) |
| 16 | + for i in inds |
| 17 | + @allowscalar b[i] = one(T) |
| 18 | + end |
40 | 19 | return b |
41 | 20 | end |
42 | 21 |
|
|
0 commit comments