Skip to content

Commit 95604fd

Browse files
authored
Merge pull request #195 from Nielsbk/adapt
This adds a file containing Adapt functions for a number of structs. This allows Adapt.adapt to be called with structs to change the underlying data type.
2 parents e3becf6 + 30306c6 commit 95604fd

File tree

9 files changed

+146
-0
lines changed

9 files changed

+146
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Francesc Verdugo <f.verdugo.rojano@vu.nl> and contributors"]
44
version = "0.5.10"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
89
CircularArrays = "7a955b69-7140-5f4e-a0ed-f168c5e2e749"
910
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -18,6 +19,7 @@ SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1920

2021
[compat]
22+
Adapt = "4.3.0"
2123
BlockArrays = "0.16, 1"
2224
CircularArrays = "1"
2325
Distances = "0.10"

src/PartitionedArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import MPI
1010
import IterativeSolvers
1111
import Distances
1212
using BlockArrays
13+
using Adapt
1314

1415
export length_to_ptrs!
1516
export rewind_ptrs!
@@ -195,4 +196,5 @@ export nullspace_linear_elasticity!
195196
export near_nullspace_linear_elasticity
196197
include("gallery.jl")
197198

199+
include("adapt.jl")
198200
end # module

src/adapt.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
function Adapt.adapt_structure(to,v::DebugArray)
3+
v = map(v) do val
4+
Adapt.adapt_structure(to,val)
5+
end
6+
end
7+
8+
function Adapt.adapt_structure(to,v::MPIArray)
9+
v = map(v) do val
10+
Adapt.adapt_structure(to,val)
11+
end
12+
end
13+
14+
function Adapt.adapt_structure(to,v::SplitMatrixBlocks)
15+
own_own = Adapt.adapt(to,v.own_own)
16+
own_ghost = Adapt.adapt(to,v.own_ghost)
17+
ghost_ghost = Adapt.adapt(to,v.ghost_ghost)
18+
ghost_own = Adapt.adapt(to,v.ghost_own)
19+
split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost)
20+
end
21+
22+
function Adapt.adapt_structure(to,v::SplitVectorBlocks)
23+
own = Adapt.adapt(to,v.own)
24+
ghost = Adapt.adapt(to,v.ghost)
25+
split_vector_blocks(own,ghost)
26+
end
27+
28+
function Adapt.adapt_structure(to,v::SplitVector)
29+
blocks = Adapt.adapt(to,v.blocks)
30+
perm = Adapt.adapt(to,v.permutation)
31+
split_vector(blocks,perm)
32+
end
33+
34+
function Adapt.adapt_structure(to,v::JaggedArray)
35+
data = Adapt.adapt_structure(to,v.data)
36+
ptrs = Adapt.adapt_structure(to,v.ptrs)
37+
jagged_array(data, ptrs)
38+
end
39+
40+
function Adapt.adapt_structure(to,v::SplitMatrix)
41+
blocks = Adapt.adapt_structure(to,v.blocks)
42+
col_per = v.col_permutation
43+
row_per = v.row_permutation
44+
split_matrix(blocks,row_per,col_per)
45+
end
46+
47+
function Adapt.adapt_structure(to,v::PSparseMatrix)
48+
matrix_partition = Adapt.adapt_structure(to,v.matrix_partition)
49+
col_par = v.col_partition
50+
row_par = v.row_partition
51+
PSparseMatrix(matrix_partition,row_par,col_par,v.assembled)
52+
end

test/adapt_tests.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Test
2+
using PartitionedArrays
3+
using Adapt
4+
5+
struct FakeCuVector{A} <: AbstractVector{Float64}
6+
vector::A
7+
end
8+
9+
Base.size(v::FakeCuVector) = size(v.vector)
10+
Base.getindex(v::FakeCuVector,i::Integer) = v.vector[i]
11+
12+
function Adapt.adapt_storage(::Type{<:FakeCuVector},x::AbstractArray)
13+
FakeCuVector(x)
14+
end
15+
16+
function adapt_tests(distribute)
17+
18+
rank = distribute(LinearIndices((2,2)))
19+
20+
a = [[1,2],[3,4,5],Int[],[3,4]]
21+
b = JaggedArray(a)
22+
c = deepcopy(b)
23+
24+
c = Adapt.adapt(FakeCuVector,c)
25+
26+
@test typeof(c.data) == FakeCuVector{typeof(b.data)}
27+
@test typeof(c.ptrs) == FakeCuVector{typeof(b.ptrs)}
28+
@test typeof(c).name.wrapper == GenericJaggedArray
29+
30+
a = [1,2,3,4,5]
31+
b = deepcopy(a)
32+
b = Adapt.adapt(FakeCuVector,b)
33+
@test typeof(b) == FakeCuVector{typeof(a)}
34+
@test b.vector == a
35+
36+
own = [1,2,3,4]
37+
ghost = [5,6,7,8]
38+
block_a = split_vector_blocks(own, ghost)
39+
block_b = deepcopy(block_a)
40+
block_b = Adapt.adapt(FakeCuVector,block_b)
41+
@test block_b.own.vector == block_a.own
42+
@test block_b.ghost.vector == block_a.ghost
43+
@test typeof(block_b.own) == FakeCuVector{typeof(block_a.own)}
44+
@test typeof(block_b.ghost) == FakeCuVector{typeof(block_a.ghost)}
45+
46+
47+
a = split_vector(block_a,[1,2,3,4,5,6,7,8])
48+
b = deepcopy(a)
49+
b = Adapt.adapt(FakeCuVector,b)
50+
51+
@test b.blocks.own.vector == a.blocks.own
52+
@test b.blocks.ghost.vector == a.blocks.ghost
53+
@test b.permutation.vector == a.permutation
54+
55+
56+
a = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
57+
b = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
58+
b = Adapt.adapt(FakeCuVector,b)
59+
60+
map(a,b) do val_a,val_b
61+
@test typeof(val_b) == FakeCuVector{typeof(val_a)}
62+
@test val_b.vector == val_a
63+
end
64+
end

test/debug_array/adapt_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module DebugArrayAdaptTests
2+
3+
using PartitionedArrays
4+
5+
include(joinpath("..","adapt_tests.jl"))
6+
7+
with_debug(adapt_tests)
8+
9+
end # module

test/debug_array/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ using PartitionedArrays
2323

2424
@testset "fem_example" begin include("fem_example.jl") end
2525

26+
@testset "adapt" begin include("adapt_tests.jl") end
27+
2628
end #module

test/mpi_array/adapt_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using MPI
2+
include("run_mpi_driver.jl")
3+
file = joinpath(@__DIR__,"drivers","adapt_tests.jl")
4+
run_mpi_driver(file;procs=4)
5+
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module MPIArrayAdaptTests
2+
3+
using PartitionedArrays
4+
5+
include(joinpath("..","..","adapt_tests.jl"))
6+
7+
with_mpi(adapt_tests)
8+
9+
end # module

test/mpi_array/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ using PartitionedArrays
1313
@testset "p_timer_tests" begin include("p_timer_tests.jl") end
1414
@testset "fdm_example" begin include("fdm_example.jl") end
1515
@testset "fem_example" begin include("fem_example.jl") end
16+
@testset "adapt" begin include("adapt_tests.jl") end
1617

1718
end #module

0 commit comments

Comments
 (0)