Skip to content

Commit 9ad7299

Browse files
Copilotdahong67
andauthored
Add JLD2 file format support (#26)
* Initial plan * Add JLD2 dependency and support in function form Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> * Add JLD2 tests and improve error handling Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> * Add file formats section to README Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> * Refactor to eliminate duplicate validation code Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> * Use explicit extension checks for improved safety Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> * Review and revise PR --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: dahong67 <9384655+dahong67@users.noreply.github.com> Co-authored-by: David Hong <hong@udel.edu>
1 parent 5b86c69 commit 9ad7299

File tree

5 files changed

+265
-16
lines changed

5 files changed

+265
-16
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ version = "0.2.0"
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
88
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
99
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
10+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1011
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1112
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1213

1314
[compat]
1415
BSON = "0.3.4"
1516
Dates = "1.10"
1617
ExpressionExplorer = "1.1.3"
18+
JLD2 = "0.4, 0.5"
1719
Logging = "1.10"
1820
MacroTools = "0.5.16"
1921
julia = "1.10"

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,55 @@ julia> a, b # b was overwritten in the first let block but not the second
168168
> This should generally work, but may not always catch all the variables - check the list
169169
> printed out to make sure. The function form `cache` can be used for more control.
170170
171+
## File formats
172+
173+
CacheVariables.jl supports two file formats, determined by the file extension:
174+
175+
- `.bson`: save using [BSON.jl](https://github.com/JuliaIO/BSON.jl),
176+
which is a lightweight format that works well for many Julia objects.
177+
- `.jld2`: save using [JLD2.jl](https://github.com/JuliaIO/JLD2.jl),
178+
which may provide better support for arbitrary Julia types.
179+
180+
Simply change the file extension to switch between formats:
181+
182+
```julia
183+
# Using BSON format
184+
cache("results.bson") do
185+
# cached computations
186+
end
187+
188+
# Using JLD2 format
189+
cache("results.jld2") do
190+
# cached computations
191+
end
192+
```
193+
194+
The same works for the macro form:
195+
196+
```julia
197+
# Using BSON format
198+
@cache "results.bson" begin
199+
# cached computations
200+
end
201+
202+
# Using JLD2 format
203+
@cache "results.jld2" begin
204+
# cached computations
205+
end
206+
```
207+
208+
The module context for loading BSON files can be set via the `bson_mod` keyword argument:
209+
210+
```julia
211+
cache("data.bson"; bson_mod = @__MODULE__) do
212+
# cached computations
213+
end
214+
```
215+
216+
This may be useful when working in modules or in Pluto notebooks
217+
(see the [BSON.jl documentation](https://github.com/JuliaIO/BSON.jl?tab=readme-ov-file#loading-custom-data-types-within-modules)
218+
for more detail).
219+
171220
## Caching the results of a sweep
172221
173222
It can be common to need to cache the results of a large sweep (e.g., over parameters or trials of a simulation).

src/CacheVariables.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module CacheVariables
22

3-
using BSON
3+
using BSON: BSON
44
using Dates: UTC, now
55
using ExpressionExplorer: compute_symbols_state
6+
using JLD2: JLD2
67
using Logging: @info
78
using MacroTools: @capture
89

src/function.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ In addition to the output of `f()`, the following metadata is saved for the run:
1111
- Time when run (in UTC)
1212
- Runtime of code (in seconds)
1313
14-
If `path` is set to `nothing`, caching is disabled and `f()` is simply run.
14+
The file extension of `path` determines the file format used:
15+
`.bson` for [BSON.jl](https://github.com/JuliaIO/BSON.jl) and
16+
`.jld2` for [JLD2.jl](https://github.com/JuliaIO/JLD2.jl).
17+
The `path` can also be set to `nothing` to disable caching and simply run `f()`.
1518
This can be useful for conditionally caching the results,
1619
e.g., to only cache a sweep when the full set is ready.
20+
1721
If `overwrite` is set to true, existing cache files will be overwritten
1822
with the results (and metadata) from a "fresh" call to `f()`.
1923
If necessary, the module to use for BSON can be set with `bson_mod`.
@@ -50,10 +54,14 @@ julia> cache(nothing) do
5054
(a = "a very time-consuming quantity to compute", b = "a very long simulation to run")
5155
```
5256
"""
53-
function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main)
54-
if isnothing(path)
55-
return f()
56-
elseif !ispath(path) || overwrite
57+
function cache(@nospecialize(f), path::AbstractString; overwrite = false, bson_mod = Main)
58+
# Check file extension
59+
ext = splitext(path)[2]
60+
(ext == ".bson" || ext == ".jld2") ||
61+
throw(ArgumentError("Only `.bson` and `.jld2` files are supported."))
62+
63+
# Save, overwrite or load
64+
if !ispath(path) || overwrite
5765
# Collect metadata and run function
5866
version = VERSION
5967
whenrun = now(UTC)
@@ -71,11 +79,39 @@ function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main)
7179

7280
# Save metadata and output
7381
mkpath(dirname(path))
74-
bson(path; version, whenrun, runtime, output)
82+
if ext == ".bson"
83+
data = Dict(
84+
:version => version,
85+
:whenrun => whenrun,
86+
:runtime => runtime,
87+
:output => output,
88+
)
89+
BSON.bson(path, data)
90+
elseif ext == ".jld2"
91+
data = Dict(
92+
"version" => version,
93+
"whenrun" => whenrun,
94+
"runtime" => runtime,
95+
"output" => output,
96+
)
97+
JLD2.save(path, data)
98+
end
7599
return output
76100
else
77101
# Load metadata and output
78-
(; version, whenrun, runtime, output) = NamedTuple(BSON.load(path, bson_mod))
102+
if ext == ".bson"
103+
data = BSON.load(path, bson_mod)
104+
version = data[:version]
105+
whenrun = data[:whenrun]
106+
runtime = data[:runtime]
107+
output = data[:output]
108+
elseif ext == ".jld2"
109+
data = JLD2.load(path)
110+
version = data["version"]
111+
whenrun = data["whenrun"]
112+
runtime = data["runtime"]
113+
output = data["output"]
114+
end
79115

80116
# Log @info message
81117
@info """
@@ -88,3 +124,4 @@ function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main)
88124
return output
89125
end
90126
end
127+
cache(@nospecialize(f), ::Nothing; kwargs...) = f()

test/runtests.jl

Lines changed: 168 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using TestItemRunner
22

3-
## Test save and load behavior of @cache macro
4-
@testitem "@cache save and load" begin
3+
## Test save and load behavior of @cache macro with BSON format
4+
@testitem "@cache save and load (BSON)" begin
55
mktempdir(@__DIR__; prefix = "temp_") do dirpath
66
path = joinpath(dirpath, "test.bson")
77

@@ -302,8 +302,8 @@ end
302302
end
303303
end
304304

305-
## Test save and load behavior of cache function
306-
@testitem "cache save and load" begin
305+
## Test save and load behavior of cache function with BSON format
306+
@testitem "cache save and load (BSON)" begin
307307
using BSON, Dates
308308
mktempdir(@__DIR__; prefix = "temp_") do dirpath
309309
funcpath = joinpath(dirpath, "functest.bson")
@@ -341,7 +341,7 @@ end
341341
return (; x = x, y = y, z = z)
342342
end
343343

344-
# 6. Load output
344+
# 6. Load the output
345345
out = cache(funcpath) do
346346
x = collect(1:3)
347347
y = 4
@@ -371,9 +371,9 @@ end
371371
@test out == (; x = [1, 2, 3], y = 4, z = "test")
372372
end
373373

374-
## Test cache in a module
375-
@testitem "cache in a module" begin
376-
module MyCacheModule
374+
## Test cache in a module (BSON)
375+
@testitem "cache in a module (BSON)" begin
376+
module MyCacheModuleBSON
377377
using CacheVariables, Test, DataFrames
378378

379379
mktempdir(@__DIR__; prefix = "temp_") do dirpath
@@ -397,4 +397,164 @@ end
397397
end
398398
end
399399

400+
## Test save and load behavior of cache function with JLD2 format
401+
@testitem "cache save and load (JLD2)" begin
402+
using JLD2, Dates
403+
mktempdir(@__DIR__; prefix = "temp_") do dirpath
404+
funcpath = joinpath(dirpath, "functest.jld2")
405+
406+
# 1. Verify log messages for saving
407+
log = (:info, r"^Saved cached values to .+\.")
408+
@test_logs log cache(funcpath) do
409+
x = collect(1:3)
410+
y = 4
411+
z = "test"
412+
return (; x = x, y = y, z = z)
413+
end
414+
415+
# 2. Delete cache and run again
416+
rm(funcpath)
417+
out = cache(funcpath) do
418+
x = collect(1:3)
419+
y = 4
420+
z = "test"
421+
return (; x = x, y = y, z = z)
422+
end
423+
424+
# 3. Verify the output
425+
@test out == (; x = [1, 2, 3], y = 4, z = "test")
426+
427+
# 4. Reset the output
428+
out = nothing
429+
430+
# 5. Verify log messages for loading
431+
log = (:info, r"^Loaded cached values from .+\.")
432+
@test_logs log cache(funcpath) do
433+
x = collect(1:3)
434+
y = 4
435+
z = "test"
436+
return (; x = x, y = y, z = z)
437+
end
438+
439+
# 6. Load the output
440+
out = cache(funcpath) do
441+
x = collect(1:3)
442+
y = 4
443+
z = "test"
444+
return (; x = x, y = y, z = z)
445+
end
446+
447+
# 7. Verify the output
448+
@test out == (; x = [1, 2, 3], y = 4, z = "test")
449+
450+
# 8. Verify the metadata
451+
data = JLD2.load(funcpath)
452+
@test data["version"] isa VersionNumber
453+
@test data["whenrun"] isa Dates.DateTime
454+
@test data["runtime"] isa Real && data["runtime"] >= 0
455+
end
456+
end
457+
458+
## Test save and load behavior of @cache macro with JLD2 format
459+
@testitem "@cache save and load (JLD2)" begin
460+
mktempdir(@__DIR__; prefix = "temp_") do dirpath
461+
path = joinpath(dirpath, "test.jld2")
462+
463+
# 1. Verify log messages for saving
464+
log1 = (:info, "Variable assignments found: x, y, z")
465+
log2 = (:info, r"^Saved cached values to .+\.")
466+
@test_logs log1 log2 (@cache path begin
467+
x = collect(1:3)
468+
y = 4
469+
z = "test"
470+
"final output"
471+
end)
472+
473+
# 2. Delete cache and run again
474+
rm(path)
475+
out = @cache path begin
476+
x = collect(1:3)
477+
y = 4
478+
z = "test"
479+
"final output"
480+
end
481+
482+
# 3. Verify that the variables enter the workspace correctly
483+
@test x == [1, 2, 3]
484+
@test y == 4
485+
@test z == "test"
486+
@test out == "final output"
487+
488+
# 4. Reset the variables
489+
x = y = z = out = nothing
490+
491+
# 5. Verify log messages for loading
492+
log1 = (:info, "Variable assignments found: x, y, z")
493+
log2 = (:info, r"^Loaded cached values from .+\.")
494+
@test_logs log1 log2 (@cache path begin
495+
x = collect(1:3)
496+
y = 4
497+
z = "test"
498+
"final output"
499+
end)
500+
501+
# 6. Load variables
502+
out = @cache path begin
503+
x = collect(1:3)
504+
y = 4
505+
z = "test"
506+
"final output"
507+
end
508+
509+
# 7. Verify that the variables enter the workspace correctly
510+
@test x == [1, 2, 3]
511+
@test y == 4
512+
@test z == "test"
513+
@test out == "final output"
514+
end
515+
end
516+
517+
## Test cache in a module (JLD2)
518+
@testitem "cache in a module (JLD2)" begin
519+
module MyCacheModuleJLD2
520+
using CacheVariables, Test, DataFrames
521+
522+
mktempdir(@__DIR__; prefix = "temp_") do dirpath
523+
modpath = joinpath(dirpath, "funcmodtest.jld2")
524+
525+
# 1. Save and check the output
526+
out = cache(modpath) do
527+
return DataFrame(; a = 1:10, b = 'a':'j')
528+
end
529+
@test out == DataFrame(; a = 1:10, b = 'a':'j')
530+
531+
# 2. Reset the output
532+
out = nothing
533+
534+
# 3. Load and check the output
535+
out = cache(modpath) do
536+
return DataFrame(; a = 1:10, b = 'a':'j')
537+
end
538+
@test out == DataFrame(; a = 1:10, b = 'a':'j')
539+
end
540+
end
541+
end
542+
543+
## Test error handling for unsupported file extensions
544+
@testitem "unsupported file extensions" begin
545+
mktempdir(@__DIR__; prefix = "temp_") do dirpath
546+
badpath = joinpath(dirpath, "test.mat")
547+
548+
# Test with function form
549+
@test_throws ArgumentError cache(badpath) do
550+
return 42
551+
end
552+
553+
# Test with macro form
554+
@test_throws ArgumentError @cache badpath begin
555+
x = 1
556+
end
557+
end
558+
end
559+
400560
@run_package_tests

0 commit comments

Comments
 (0)