Skip to content

Commit edabc4a

Browse files
authored
Use Adapt.jl (#57)
1 parent 4aba375 commit edabc4a

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ name = "OffsetArrays"
22
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
33
version = "1.3.1"
44

5+
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7+
58
[compat]
9+
Adapt = "2"
610
julia = "0.7, 1"
711

812
[extras]
@@ -12,7 +16,8 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1216
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1317
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
1418
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1520
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1621

1722
[targets]
18-
test = ["Aqua", "CatIndices", "DelimitedFiles", "Documenter", "Test", "LinearAlgebra", "EllipsisNotation"]
23+
test = ["Aqua", "CatIndices", "DelimitedFiles", "Documenter", "Test", "LinearAlgebra", "EllipsisNotation", "StaticArrays"]

src/OffsetArrays.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ parenttype(A::OffsetArray) = parenttype(typeof(A))
198198

199199
Base.parent(A::OffsetArray) = A.parent
200200

201+
# TODO: Ideally we would delegate to the parent's broadcasting implementation, but that
202+
# is currently broken in sufficiently many implementation, namely RecursiveArrayTools, DistributedArrays
203+
# and StaticArrays, that it will take concentrated effort to get this working across the ecosystem.
204+
# The goal would be to have `OffsetArray(CuArray) .+ 1 == OffsetArray{CuArray}`.
205+
# Base.Broadcast.BroadcastStyle(::Type{<:OffsetArray{<:Any, <:Any, AA}}) where AA = Base.Broadcast.BroadcastStyle(AA)
206+
201207
Base.eachindex(::IndexCartesian, A::OffsetArray) = CartesianIndices(axes(A))
202208
Base.eachindex(::IndexLinear, A::OffsetVector) = axes(A, 1)
203209

@@ -473,4 +479,10 @@ if VERSION < v"1.1.0-DEV.783"
473479
Base.copyfirst!(dest::OffsetArray, src::OffsetArray) = (maximum!(parent(dest), parent(src)); return dest)
474480
end
475481

482+
##
483+
# Adapt allows for automatic conversion of CPU OffsetArrays to GPU OffsetArrays
484+
##
485+
import Adapt
486+
Adapt.adapt_structure(to, x::OffsetArray) = OffsetArray(Adapt.adapt(to, parent(x)), x.offsets)
487+
476488
end # module

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using LinearAlgebra
66
using DelimitedFiles
77
using CatIndices: BidirectionalVector
88
using EllipsisNotation
9+
using Adapt
10+
using StaticArrays
911

1012
# https://github.com/JuliaLang/julia/pull/29440
1113
if VERSION < v"1.1.0-DEV.389"
@@ -1378,4 +1380,14 @@ end
13781380
@test searchsorted(o, 6) == 3:2
13791381
end
13801382

1383+
@testset "Adapt" begin
1384+
# We need another storage type, CUDA.jl defines one but we can't use that for CI
1385+
# let's define an appropriate method for SArrays
1386+
Adapt.adapt_storage(::Type{SA}, xs::AbstractArray) where SA<:SArray = convert(SA, xs)
1387+
arr = OffsetArray(rand(3, 3), -1:1, -1:1)
1388+
s_arr = adapt(SMatrix{3,3}, arr)
1389+
@test parent(s_arr) isa SArray
1390+
@test arr == adapt(Array, s_arr)
1391+
end
1392+
13811393
include("origin.jl")

0 commit comments

Comments
 (0)