Skip to content

Commit 3c16583

Browse files
JLD2 extension take 2 (#597)
Co-authored-by: lassepe <[email protected]>
1 parent dc49c29 commit 3c16583

File tree

5 files changed

+43
-0
lines changed

5 files changed

+43
-0
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@ ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1515
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1616
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1717

18+
[weakdeps]
19+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
20+
21+
[extensions]
22+
JLD2Ext = "JLD2"
23+
1824
[compat]
1925
Adapt = "4.0"
2026
GPUArraysCore = "= 0.2.0"
27+
JLD2 = "0.4, 0.5"
2128
KernelAbstractions = "0.9.28"
2229
LLVM = "3.9, 4, 5, 6, 7, 8, 9"
2330
LinearAlgebra = "1"

ext/JLD2Ext.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module JLD2Ext
2+
3+
using GPUArrays: AbstractGPUArray
4+
using JLD2: JLD2
5+
6+
JLD2.writeas(::Type{<:AbstractGPUArray{T, N}}) where {T, N} = Array{T, N}
7+
8+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
44
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
55
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
6+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
67
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/testsuite.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using GPUArrays
1010

1111
using KernelAbstractions
1212
using LinearAlgebra
13+
using JLD2
1314
using Random
1415
using Test
1516

@@ -94,6 +95,7 @@ include("testsuite/random.jl")
9495
include("testsuite/uniformscaling.jl")
9596
include("testsuite/statistics.jl")
9697
include("testsuite/alloc_cache.jl")
98+
include("testsuite/jld2ext.jl")
9799

98100
"""
99101
Runs the entire GPUArrays test suite on array type `AT`

test/testsuite/jld2ext.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testsuite "ext/jld2" (AT, eltypes) -> begin
2+
mktempdir() do dir
3+
for ET in eltypes
4+
@testset "$ET" begin
5+
# Test with different array sizes and dimensions
6+
for dims in ((2,), (2, 2), (2, 2, 2))
7+
# Create a random array
8+
x = AT(rand(ET, dims...))
9+
10+
# Save to a temporary file
11+
file = joinpath(dir, "test.jld2")
12+
13+
# Save and load
14+
JLD2.save_object(file, x)
15+
y = JLD2.load_object(file)
16+
17+
# Verify the loaded array matches the original
18+
@test y isa AT{ET, length(dims)}
19+
@test size(y) == size(x)
20+
@test x == y
21+
end
22+
end
23+
end
24+
end
25+
end

0 commit comments

Comments
 (0)