diff --git a/Project.toml b/Project.toml index d94c8ad..2427a68 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.2.0" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -14,6 +15,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" BSON = "0.3.4" Dates = "1.10" ExpressionExplorer = "1.1.3" +JLD2 = "0.4, 0.5" Logging = "1.10" MacroTools = "0.5.16" julia = "1.10" diff --git a/README.md b/README.md index 54e5162..e5ac56e 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,55 @@ julia> a, b # b was overwritten in the first let block but not the second > This should generally work, but may not always catch all the variables - check the list > printed out to make sure. The function form `cache` can be used for more control. +## File formats + +CacheVariables.jl supports two file formats, determined by the file extension: + +- `.bson`: save using [BSON.jl](https://github.com/JuliaIO/BSON.jl), + which is a lightweight format that works well for many Julia objects. +- `.jld2`: save using [JLD2.jl](https://github.com/JuliaIO/JLD2.jl), + which may provide better support for arbitrary Julia types. + +Simply change the file extension to switch between formats: + +```julia +# Using BSON format +cache("results.bson") do + # cached computations +end + +# Using JLD2 format +cache("results.jld2") do + # cached computations +end +``` + +The same works for the macro form: + +```julia +# Using BSON format +@cache "results.bson" begin + # cached computations +end + +# Using JLD2 format +@cache "results.jld2" begin + # cached computations +end +``` + +The module context for loading BSON files can be set via the `bson_mod` keyword argument: + +```julia +cache("data.bson"; bson_mod = @__MODULE__) do + # cached computations +end +``` + +This may be useful when working in modules or in Pluto notebooks +(see the [BSON.jl documentation](https://github.com/JuliaIO/BSON.jl?tab=readme-ov-file#loading-custom-data-types-within-modules) +for more detail). + ## Caching the results of a sweep It can be common to need to cache the results of a large sweep (e.g., over parameters or trials of a simulation). diff --git a/src/CacheVariables.jl b/src/CacheVariables.jl index 59604c2..63fae7f 100755 --- a/src/CacheVariables.jl +++ b/src/CacheVariables.jl @@ -1,8 +1,9 @@ module CacheVariables -using BSON +using BSON: BSON using Dates: UTC, now using ExpressionExplorer: compute_symbols_state +using JLD2: JLD2 using Logging: @info using MacroTools: @capture diff --git a/src/function.jl b/src/function.jl index c0f4516..f6f1134 100644 --- a/src/function.jl +++ b/src/function.jl @@ -11,9 +11,13 @@ In addition to the output of `f()`, the following metadata is saved for the run: - Time when run (in UTC) - Runtime of code (in seconds) -If `path` is set to `nothing`, caching is disabled and `f()` is simply run. +The file extension of `path` determines the file format used: +`.bson` for [BSON.jl](https://github.com/JuliaIO/BSON.jl) and +`.jld2` for [JLD2.jl](https://github.com/JuliaIO/JLD2.jl). +The `path` can also be set to `nothing` to disable caching and simply run `f()`. This can be useful for conditionally caching the results, e.g., to only cache a sweep when the full set is ready. + If `overwrite` is set to true, existing cache files will be overwritten with the results (and metadata) from a "fresh" call to `f()`. If necessary, the module to use for BSON can be set with `bson_mod`. @@ -50,10 +54,14 @@ julia> cache(nothing) do (a = "a very time-consuming quantity to compute", b = "a very long simulation to run") ``` """ -function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main) - if isnothing(path) - return f() - elseif !ispath(path) || overwrite +function cache(@nospecialize(f), path::AbstractString; overwrite = false, bson_mod = Main) + # Check file extension + ext = splitext(path)[2] + (ext == ".bson" || ext == ".jld2") || + throw(ArgumentError("Only `.bson` and `.jld2` files are supported.")) + + # Save, overwrite or load + if !ispath(path) || overwrite # Collect metadata and run function version = VERSION whenrun = now(UTC) @@ -71,11 +79,39 @@ function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main) # Save metadata and output mkpath(dirname(path)) - bson(path; version, whenrun, runtime, output) + if ext == ".bson" + data = Dict( + :version => version, + :whenrun => whenrun, + :runtime => runtime, + :output => output, + ) + BSON.bson(path, data) + elseif ext == ".jld2" + data = Dict( + "version" => version, + "whenrun" => whenrun, + "runtime" => runtime, + "output" => output, + ) + JLD2.save(path, data) + end return output else # Load metadata and output - (; version, whenrun, runtime, output) = NamedTuple(BSON.load(path, bson_mod)) + if ext == ".bson" + data = BSON.load(path, bson_mod) + version = data[:version] + whenrun = data[:whenrun] + runtime = data[:runtime] + output = data[:output] + elseif ext == ".jld2" + data = JLD2.load(path) + version = data["version"] + whenrun = data["whenrun"] + runtime = data["runtime"] + output = data["output"] + end # Log @info message @info """ @@ -88,3 +124,4 @@ function cache(@nospecialize(f), path; overwrite = false, bson_mod = Main) return output end end +cache(@nospecialize(f), ::Nothing; kwargs...) = f() diff --git a/test/runtests.jl b/test/runtests.jl index 7731dec..8f32397 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using TestItemRunner -## Test save and load behavior of @cache macro -@testitem "@cache save and load" begin +## Test save and load behavior of @cache macro with BSON format +@testitem "@cache save and load (BSON)" begin mktempdir(@__DIR__; prefix = "temp_") do dirpath path = joinpath(dirpath, "test.bson") @@ -302,8 +302,8 @@ end end end -## Test save and load behavior of cache function -@testitem "cache save and load" begin +## Test save and load behavior of cache function with BSON format +@testitem "cache save and load (BSON)" begin using BSON, Dates mktempdir(@__DIR__; prefix = "temp_") do dirpath funcpath = joinpath(dirpath, "functest.bson") @@ -341,7 +341,7 @@ end return (; x = x, y = y, z = z) end - # 6. Load output + # 6. Load the output out = cache(funcpath) do x = collect(1:3) y = 4 @@ -371,9 +371,9 @@ end @test out == (; x = [1, 2, 3], y = 4, z = "test") end -## Test cache in a module -@testitem "cache in a module" begin - module MyCacheModule +## Test cache in a module (BSON) +@testitem "cache in a module (BSON)" begin + module MyCacheModuleBSON using CacheVariables, Test, DataFrames mktempdir(@__DIR__; prefix = "temp_") do dirpath @@ -397,4 +397,164 @@ end end end +## Test save and load behavior of cache function with JLD2 format +@testitem "cache save and load (JLD2)" begin + using JLD2, Dates + mktempdir(@__DIR__; prefix = "temp_") do dirpath + funcpath = joinpath(dirpath, "functest.jld2") + + # 1. Verify log messages for saving + log = (:info, r"^Saved cached values to .+\.") + @test_logs log cache(funcpath) do + x = collect(1:3) + y = 4 + z = "test" + return (; x = x, y = y, z = z) + end + + # 2. Delete cache and run again + rm(funcpath) + out = cache(funcpath) do + x = collect(1:3) + y = 4 + z = "test" + return (; x = x, y = y, z = z) + end + + # 3. Verify the output + @test out == (; x = [1, 2, 3], y = 4, z = "test") + + # 4. Reset the output + out = nothing + + # 5. Verify log messages for loading + log = (:info, r"^Loaded cached values from .+\.") + @test_logs log cache(funcpath) do + x = collect(1:3) + y = 4 + z = "test" + return (; x = x, y = y, z = z) + end + + # 6. Load the output + out = cache(funcpath) do + x = collect(1:3) + y = 4 + z = "test" + return (; x = x, y = y, z = z) + end + + # 7. Verify the output + @test out == (; x = [1, 2, 3], y = 4, z = "test") + + # 8. Verify the metadata + data = JLD2.load(funcpath) + @test data["version"] isa VersionNumber + @test data["whenrun"] isa Dates.DateTime + @test data["runtime"] isa Real && data["runtime"] >= 0 + end +end + +## Test save and load behavior of @cache macro with JLD2 format +@testitem "@cache save and load (JLD2)" begin + mktempdir(@__DIR__; prefix = "temp_") do dirpath + path = joinpath(dirpath, "test.jld2") + + # 1. Verify log messages for saving + log1 = (:info, "Variable assignments found: x, y, z") + log2 = (:info, r"^Saved cached values to .+\.") + @test_logs log1 log2 (@cache path begin + x = collect(1:3) + y = 4 + z = "test" + "final output" + end) + + # 2. Delete cache and run again + rm(path) + out = @cache path begin + x = collect(1:3) + y = 4 + z = "test" + "final output" + end + + # 3. Verify that the variables enter the workspace correctly + @test x == [1, 2, 3] + @test y == 4 + @test z == "test" + @test out == "final output" + + # 4. Reset the variables + x = y = z = out = nothing + + # 5. Verify log messages for loading + log1 = (:info, "Variable assignments found: x, y, z") + log2 = (:info, r"^Loaded cached values from .+\.") + @test_logs log1 log2 (@cache path begin + x = collect(1:3) + y = 4 + z = "test" + "final output" + end) + + # 6. Load variables + out = @cache path begin + x = collect(1:3) + y = 4 + z = "test" + "final output" + end + + # 7. Verify that the variables enter the workspace correctly + @test x == [1, 2, 3] + @test y == 4 + @test z == "test" + @test out == "final output" + end +end + +## Test cache in a module (JLD2) +@testitem "cache in a module (JLD2)" begin + module MyCacheModuleJLD2 + using CacheVariables, Test, DataFrames + + mktempdir(@__DIR__; prefix = "temp_") do dirpath + modpath = joinpath(dirpath, "funcmodtest.jld2") + + # 1. Save and check the output + out = cache(modpath) do + return DataFrame(; a = 1:10, b = 'a':'j') + end + @test out == DataFrame(; a = 1:10, b = 'a':'j') + + # 2. Reset the output + out = nothing + + # 3. Load and check the output + out = cache(modpath) do + return DataFrame(; a = 1:10, b = 'a':'j') + end + @test out == DataFrame(; a = 1:10, b = 'a':'j') + end + end +end + +## Test error handling for unsupported file extensions +@testitem "unsupported file extensions" begin + mktempdir(@__DIR__; prefix = "temp_") do dirpath + badpath = joinpath(dirpath, "test.mat") + + # Test with function form + @test_throws ArgumentError cache(badpath) do + return 42 + end + + # Test with macro form + @test_throws ArgumentError @cache badpath begin + x = 1 + end + end +end + @run_package_tests