diff --git a/docs/src/index.md b/docs/src/index.md index dbd6fae..847d2b3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -27,4 +27,6 @@ Pages = ["index.md"] @io2str @test_reference psnr_equality +default_equality_selector! +with_default_equality_selector ``` diff --git a/src/ReferenceTests.jl b/src/ReferenceTests.jl index 226ee96..f63e7b5 100644 --- a/src/ReferenceTests.jl +++ b/src/ReferenceTests.jl @@ -16,7 +16,9 @@ export @withcolor, @io2str, @test_reference, - psnr_equality + psnr_equality, + default_equality_selector!, + with_default_equality_selector include("utils.jl") include("test_reference.jl") diff --git a/src/test_reference.jl b/src/test_reference.jl index 98af86f..c0d0698 100644 --- a/src/test_reference.jl +++ b/src/test_reference.jl @@ -16,6 +16,8 @@ Arguments: * `filename::String`: _relative_ path to the file that contains the macro invocation. * `expr`: the actual content used to compare. * `by`: the equality test function. By default it is `isequal` if not explicitly stated. + A custom selector function that returns an equality function for a given `(reference, actual)` input + can be set permanently using `default_equality_selector!` or temporarily using `with_default_equality_selector`. * `format`: Force reading the file using a specific format # Types @@ -102,6 +104,60 @@ macro test_reference(reference, actual, kws...) expr end +function default_equality end + +const DEFAULT_EQUALITY_SELECTOR = Ref{Any}(default_equality) + +""" + default_equality_selector!() + default_equality_selector!(selector) + +Set `selector` as the global default function which is called in each reference test +as `eq = selector(reference, actual)` where `eq` is the selected equality function for `actual` and `reference` +which is then called as `eq(reference, actual)` to determine if the reference matches. +If no input argument is given, the selector function is reset to the default. +The equality function for a given `@test_reference` can be overridden using the `by` keyword. + +## Example + +```julia +# we have some image comparison function we want to use by default +custom_image_equality(reference, actual) = ... + +# our selector picks `custom_image_equality` for images and `isequal` for the rest +custom_selector(actual, reference) = isequal +custom_selector(reference::AbstractArray{<:Colorant}, actual::AbstractArray{<:Colorant}) = + custom_image_equality + +default_equality_selector!(custom_selector) + +# this test now uses `custom_image_equality` +@test_reference "image.png" some_image +``` +""" +function default_equality_selector!(selector = default_equality) + DEFAULT_EQUALITY_SELECTOR[] = selector + return +end + +""" + with_default_equality_selector(f, selector) + +Execute `f` while `selector` is set as the default global equality function selector using `default_equality_selector!`. +The previous default selector is restored automatically after `f` succeeds or fails. +""" +function with_default_equality_selector(f, selector) + old = DEFAULT_EQUALITY_SELECTOR[] + try + default_equality_selector!(selector) + result = f() + default_equality_selector!() + result + finally + DEFAULT_EQUALITY_SELECTOR[] = old + end +end + function test_reference( filename::AbstractString, raw_actual; by = nothing, render = nothing, format = nothing, kw...) @@ -145,7 +201,7 @@ function test_reference( if equiv === nothing # generally, `reference` and `actual` are of the same type after preprocessing - equiv = default_equality(reference, actual) + equiv = DEFAULT_EQUALITY_SELECTOR[](reference, actual) end if equiv(reference, actual) diff --git a/test/runtests.jl b/test/runtests.jl index e30914a..14ca567 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,8 @@ using Plots strip_summary(content::String) = join(split(content, "\n")[2:end], "\n") +struct CustomEqualityError <: Exception end + @testset "ReferenceTests" begin # load/create some example images @@ -201,5 +203,21 @@ end arr_float = [pi, pi/2, 1.0] @test_reference file Dict(:ar=>arr_float) by=comp end +@testset "Default equality selector" begin + custom_equality(reference, actual) = (r, a) -> throw(CustomEqualityError()) + + default_equality_selector!(custom_equality) + @test_throws CustomEqualityError @test_reference "references/camera.png" imresize(camera, (64,64)) + @test_reference "references/camera.png" imresize(camera, (64,64)) by=psnr_equality() + + default_equality_selector!() + @test_reference "references/camera.png" imresize(camera, (64,64)) + + with_default_equality_selector(custom_equality) do + @test_throws CustomEqualityError @test_reference "references/camera.png" imresize(camera, (64,64)) + @test_reference "references/camera.png" imresize(camera, (64,64)) by=psnr_equality() + end + @test_reference "references/camera.png" imresize(camera, (64,64)) +end end # top level testset