diff --git a/Project.toml b/Project.toml index ef0e644..ced5011 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" @@ -21,5 +20,4 @@ Random = "<0.0.1, 1" Reexport = "1" Statistics = "<0.0.1, 1" XAIBase = "4" -Zygote = "0.6" julia = "1.10" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index d2a943a..e6b0133 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -4,6 +4,7 @@ ExplainableAI = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BenchmarkTools = "1" diff --git a/benchmark/bench_jogger.jl b/benchmark/bench_jogger.jl index e786975..bda981d 100644 --- a/benchmark/bench_jogger.jl +++ b/benchmark/bench_jogger.jl @@ -1,4 +1,5 @@ using BenchmarkTools +using Zygote using Flux using ExplainableAI diff --git a/docs/Project.toml b/docs/Project.toml index 3fff181..cab982e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,6 +11,7 @@ ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -VisionHeatmaps = "1.4" \ No newline at end of file +VisionHeatmaps = "1.4" diff --git a/docs/src/literate/augmentations.jl b/docs/src/literate/augmentations.jl index 0b7438b..939f43f 100644 --- a/docs/src/literate/augmentations.jl +++ b/docs/src/literate/augmentations.jl @@ -8,6 +8,7 @@ # and start out by loading the same pre-trained LeNet5 model and MNIST input data: using ExplainableAI using VisionHeatmaps +using Zygote using Flux using BSON # hide diff --git a/docs/src/literate/example.jl b/docs/src/literate/example.jl index 7efa543..3dd1d4b 100644 --- a/docs/src/literate/example.jl +++ b/docs/src/literate/example.jl @@ -7,7 +7,6 @@ # For this first example, we already have loaded a pre-trained LeNet5 model # to look at explanations on the MNIST dataset. -using ExplainableAI using Flux using BSON # hide @@ -41,8 +40,11 @@ input = reshape(x, 28, 28, 1, :); #md # (width, height, channels, batch), which is Flux.jl's convention. # ## Explanations -# We can now select an analyzer of our choice and call [`analyze`](@ref) -# to get an [`Explanation`](@ref): +# We can now select an analyzer of our choice and call [`analyze`](@ref) to get an [`Explanation`](@ref). +# Note that for gradient-based optimizers, a backend for automatic differentiation must be loaded, by default [Zygote.jl](https://github.com/FluxML/Zygote.jl): +using ExplainableAI +using Zygote + analyzer = InputTimesGradient(model) expl = analyze(input, analyzer); diff --git a/docs/src/literate/heatmapping.jl b/docs/src/literate/heatmapping.jl index 9289dff..5351571 100644 --- a/docs/src/literate/heatmapping.jl +++ b/docs/src/literate/heatmapping.jl @@ -10,6 +10,7 @@ # We start out by loading the same pre-trained LeNet5 model and MNIST input data: using ExplainableAI using VisionHeatmaps +using Zygote using Flux using BSON # hide diff --git a/src/ExplainableAI.jl b/src/ExplainableAI.jl index 567971f..2ed8ea9 100644 --- a/src/ExplainableAI.jl +++ b/src/ExplainableAI.jl @@ -11,7 +11,6 @@ using Random: AbstractRNG, GLOBAL_RNG # Automatic differentiation using ADTypes: AbstractADType, AutoZygote using DifferentiationInterface: value_and_pullback -using Zygote const DEFAULT_AD_BACKEND = AutoZygote() include("bibliography.jl") diff --git a/test/Project.toml b/test/Project.toml index a616a45..a8c6b4b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,3 +13,4 @@ ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/test_batches.jl b/test/test_batches.jl index 832d86e..bcd1a70 100644 --- a/test/test_batches.jl +++ b/test/test_batches.jl @@ -1,4 +1,5 @@ using ExplainableAI +using Zygote using Test using Flux diff --git a/test/test_cnn.jl b/test/test_cnn.jl index b8ff8f3..a49d7d8 100644 --- a/test/test_cnn.jl +++ b/test/test_cnn.jl @@ -1,4 +1,5 @@ using ExplainableAI +using Zygote using Test using ReferenceTests