Skip to content

Commit e0168ce

Browse files
Merge pull request #488 from simenhu/feature/adapt_on_arraypartition
Adapt rule for ArrayPartition
2 parents ec2ee49 + 4425ce9 commit e0168ce

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

src/array_partition.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,3 +579,7 @@ end
579579
end
580580
return sum_expr
581581
end
582+
583+
function Adapt.adapt_structure(to, ap::ArrayPartition)
584+
ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x)...)
585+
end

test/gpu/arraypartition_gpu.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, CUDA, Test
1+
using RecursiveArrayTools, CUDA, Test, Adapt
22
CUDA.allowscalar(false)
33

44
# Test indexing with colon
@@ -21,3 +21,23 @@ fill!(pA, false)
2121
a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
2222
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
2323
@. a + b
24+
25+
# Test adapt from ArrayPartition with CuArrays to ArrayPartition with CPU arrays
26+
27+
a = CuArray(Float64.([1., 2., 3., 4.]))
28+
b = CuArray(Float64.([1., 2., 3., 4.]))
29+
part_a_gpu = ArrayPartition(a, b)
30+
part_a = adapt(Array{Float32}, part_a_gpu)
31+
32+
c = Float32.([1., 2., 3., 4.])
33+
d = Float32.([1., 2., 3., 4.])
34+
part_b = ArrayPartition(c, d)
35+
36+
@test part_a == part_b # Test equality
37+
38+
for i in 1:length(part_a.x)
39+
sub_a = part_a.x[i]
40+
sub_b = part_b.x[i]
41+
@test sub_a == sub_b # Test for value equality in sub-arrays
42+
@test typeof(sub_a) === typeof(sub_b) # Test type equality
43+
end

test/partitions_test.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test, Statistics, ArrayInterface
1+
using RecursiveArrayTools, Test, Statistics, ArrayInterface, Adapt
22

33
@test length(ArrayPartition()) == 0
44
@test isempty(ArrayPartition())
@@ -306,3 +306,22 @@ end
306306
copyto!(u, ArrayPartition(1.0, -1.2))
307307
@test u == [1.0, -1.2]
308308
end
309+
310+
# Test adapt on ArrayPartition from Float64 to Float32 arrays
311+
a = Float64.([1., 2., 3., 4.])
312+
b = Float64.([1., 2., 3., 4.])
313+
part_a_64 = ArrayPartition(a, b)
314+
part_a = adapt(Array{Float32}, part_a_64)
315+
316+
c = Float32.([1., 2., 3., 4.])
317+
d = Float32.([1., 2., 3., 4.])
318+
part_b = ArrayPartition(c, d)
319+
320+
@test part_a == part_b # Test equality of partitions
321+
322+
for i in 1:length(part_a.x)
323+
sub_a = part_a.x[i]
324+
sub_b = part_b.x[i]
325+
@test sub_a == sub_b # Test for value equality
326+
@test typeof(sub_a) === typeof(sub_b) # Test type equality
327+
end

0 commit comments

Comments
 (0)