From e525614ab49e2d11ae9538f3e59fce9efc293555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 9 Dec 2020 19:43:25 +0100 Subject: [PATCH 01/48] Use broadcasting instead of map for kerneldiagmatrix --- src/matrix/kernelmatrix.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index a7262a243..5324cdd69 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -71,10 +71,9 @@ function kerneldiagmatrix!( return map!(κ, x, y) end -kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x) - -kerneldiagmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, y) +kerneldiagmatrix(κ::Kernel, x::AbstractVector) = κ.(x, x) +kerneldiagmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) = κ.(x, y) # From e56492a349924efabf8757cf4ce778fa891eada9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 9 Dec 2020 19:43:47 +0100 Subject: [PATCH 02/48] Removed method for transformedkernel --- src/kernels/transformedkernel.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 19c369160..eec13a97d 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -78,10 +78,6 @@ function kernelmatrix!( return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) end -function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector) - return kerneldiagmatrix(κ.kernel, _map(κ.transform, x)) -end - function kernelmatrix(κ::TransformedKernel, x::AbstractVector) return kernelmatrix(kernel(κ), _map(κ.transform, x)) end From 35a63069b6930b6707914d16f47d18c5ca7b12e3 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 14 Dec 2020 16:35:18 +0100 Subject: [PATCH 03/48] Restored functions and applied suggestions --- src/kernels/transformedkernel.jl | 12 ++++++++++++ src/matrix/kernelmatrix.jl | 12 ++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index eec13a97d..e8403f7b1 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -68,6 +68,10 @@ function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::Abstract return kerneldiagmatrix!(K, κ.kernel, _map(κ.transform, x)) end +function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector) + return kerneldiagmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +end + function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) return kernelmatrix!(K, kernel(κ), _map(κ.transform, x)) end @@ -78,6 +82,14 @@ function kernelmatrix!( return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) end +function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector) + return kerneldiagmatrix(κ.kernel, _map(κ.transform, x)) +end + +function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) + return kerneldiagmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +end + function kernelmatrix(κ::TransformedKernel, x::AbstractVector) return kernelmatrix(kernel(κ), _map(κ.transform, x)) end diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index 5324cdd69..5970a5517 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -71,9 +71,9 @@ function kerneldiagmatrix!( return map!(κ, x, y) end -kerneldiagmatrix(κ::Kernel, x::AbstractVector) = κ.(x, x) +kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x) -kerneldiagmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) = κ.(x, y) +kerneldiagmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, y) # @@ -103,6 +103,14 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) return map(d -> kappa(κ, d), pairwise(metric(κ), x, y)) end +function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector) + return map(x -> κ(x, x), x) +end + +function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) + return map(d -> kappa(κ, d), map(metric(κ), x, y)) +end + # From 25e5efd72896e97a976d6daa238a20980aee0cfd Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 14 Dec 2020 16:35:30 +0100 Subject: [PATCH 04/48] Added tests for diagmatrix --- test/utils_AD.jl | 180 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 test/utils_AD.jl diff --git a/test/utils_AD.jl b/test/utils_AD.jl new file mode 100644 index 000000000..c86e8e218 --- /dev/null +++ b/test/utils_AD.jl @@ -0,0 +1,180 @@ + +const FDM = FiniteDifferences.central_fdm(5, 1) + +gradient(f, s::Symbol, args) = gradient(f, Val(s), args) + +function gradient(f, ::Val{:Zygote}, args) + g = first(Zygote.gradient(f, args)) + if isnothing(g) + if args isa AbstractArray{<:Real} + return zeros(size(args)) # To respect the same output as other ADs + else + return zeros.(size.(args)) + end + else + return g + end +end + +function gradient(f, ::Val{:ForwardDiff}, args) + ForwardDiff.gradient(f, args) +end + +function gradient(f, ::Val{:ReverseDiff}, args) + ReverseDiff.gradient(f, args) +end + +function gradient(f, ::Val{:FiniteDiff}, args) + first(FiniteDifferences.grad(FDM, f, args)) +end + +function compare_gradient(f, AD::Symbol, args) + grad_AD = gradient(f, AD, args) + grad_FD = gradient(f, :FiniteDiff, args) + @test grad_AD ≈ grad_FD atol=1e-8 rtol=1e-5 +end + +testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim)) +testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim)) +testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim = dim)) +testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim = dim)) + +function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3]) + test_fd = test_FiniteDiff(kernelfunction, args, dims) + if !test_fd.anynonpass + for AD in ADs + test_AD(AD, kernelfunction, args, dims) + end + end +end + +function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) + # Init arguments : + k = if args === nothing + kernelfunction() + else + kernelfunction(args) + end + rng = MersenneTwister(42) + @testset "FiniteDifferences" begin + if k isa SimpleKernel + for d in log.([eps(), rand(rng)]) + @test_nowarn gradient(:FiniteDiff, [d]) do x + kappa(k, exp(first(x))) + end + end + end + ## Testing Kernel Functions + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + @test_nowarn gradient(:FiniteDiff, x) do x + k(x, y) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + kernelfunction(p)(x, y) + end + end + ## Testing Kernel Matrices + A = rand(rng, dims...) + B = rand(rng, dims...) + for dim in 1:2 + @test_nowarn gradient(:FiniteDiff, A) do a + testfunction(k, a, dim) + end + @test_nowarn gradient(:FiniteDiff , A) do a + testfunction(k, a, B, dim) + end + @test_nowarn gradient(:FiniteDiff, B) do b + testfunction(k, A, b, dim) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + testfunction(kernelfunction(p), A, B, dim) + end + end + + @test_nowarn gradient(:FiniteDiff, A) do a + testdiagfunction(k, a, dim) + end + @test_nowarn gradient(:FiniteDiff , A) do a + testdiagfunction(k, a, B, dim) + end + @test_nowarn gradient(:FiniteDiff, B) do b + testdiagfunction(k, A, b, dim) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + testdiagfunction(kernelfunction(p), A, B, dim) + end + end + end + end +end + +function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) + @testset "$(AD)" begin + # Test kappa function + k = if args === nothing + kernelfunction() + else + kernelfunction(args) + end + rng = MersenneTwister(42) + if k isa SimpleKernel + for d in log.([eps(), rand(rng)]) + compare_gradient(AD, [d]) do x + kappa(k, exp(x[1])) + end + end + end + # Testing kernel evaluations + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + compare_gradient(AD, x) do x + k(x, y) + end + compare_gradient(AD, y) do y + k(x, y) + end + if !(args === nothing) + compare_gradient(AD, args) do p + kernelfunction(p)(x,y) + end + end + # Testing kernel matrices + A = rand(rng, dims...) + B = rand(rng, dims...) + for dim in 1:2 + compare_gradient(AD, A) do a + testfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testfunction(k, A, b, dim) + end + if !(args === nothing) + compare_gradient(AD, args) do p + testfunction(kernelfunction(p), A, dim) + end + end + + compare_gradient(AD, A) do a + testdiagfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testdiagfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testdiagfunction(k, A, b, dim) + end + if !(args === nothing) + compare_gradient(AD, args) do p + testdiagfunction(kernelfunction(p), A, dim) + end + end + end + end +end From 2f85ebc1e0abc2f9c3d22375bb43ed15436a79fd Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 14 Dec 2020 18:05:39 +0100 Subject: [PATCH 05/48] Put changes to the right file and removed utils_AD.jl --- test/test_utils.jl | 32 ++++++++ test/utils_AD.jl | 180 --------------------------------------------- 2 files changed, 32 insertions(+), 180 deletions(-) delete mode 100644 test/utils_AD.jl diff --git a/test/test_utils.jl b/test/test_utils.jl index a94f1c54b..19ac4eeb4 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -50,6 +50,8 @@ end testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim)) testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim)) +testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim = dim)) +testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim = dim)) function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3]) test_fd = test_FiniteDiff(kernelfunction, args, dims) @@ -105,6 +107,21 @@ function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) testfunction(kernelfunction(p), A, B, dim) end end + + @test_nowarn gradient(:FiniteDiff, A) do a + testdiagfunction(k, a, dim) + end + @test_nowarn gradient(:FiniteDiff , A) do a + testdiagfunction(k, a, B, dim) + end + @test_nowarn gradient(:FiniteDiff, B) do b + testdiagfunction(k, A, b, dim) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + testdiagfunction(kernelfunction(p), A, B, dim) + end + end end end end @@ -157,6 +174,21 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) testfunction(kernelfunction(p), A, dim) end end + + compare_gradient(AD, A) do a + testdiagfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testdiagfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testdiagfunction(k, A, b, dim) + end + if !(args === nothing) + compare_gradient(AD, args) do p + testdiagfunction(kernelfunction(p), A, dim) + end + end end end end diff --git a/test/utils_AD.jl b/test/utils_AD.jl deleted file mode 100644 index c86e8e218..000000000 --- a/test/utils_AD.jl +++ /dev/null @@ -1,180 +0,0 @@ - -const FDM = FiniteDifferences.central_fdm(5, 1) - -gradient(f, s::Symbol, args) = gradient(f, Val(s), args) - -function gradient(f, ::Val{:Zygote}, args) - g = first(Zygote.gradient(f, args)) - if isnothing(g) - if args isa AbstractArray{<:Real} - return zeros(size(args)) # To respect the same output as other ADs - else - return zeros.(size.(args)) - end - else - return g - end -end - -function gradient(f, ::Val{:ForwardDiff}, args) - ForwardDiff.gradient(f, args) -end - -function gradient(f, ::Val{:ReverseDiff}, args) - ReverseDiff.gradient(f, args) -end - -function gradient(f, ::Val{:FiniteDiff}, args) - first(FiniteDifferences.grad(FDM, f, args)) -end - -function compare_gradient(f, AD::Symbol, args) - grad_AD = gradient(f, AD, args) - grad_FD = gradient(f, :FiniteDiff, args) - @test grad_AD ≈ grad_FD atol=1e-8 rtol=1e-5 -end - -testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim)) -testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim)) -testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim = dim)) -testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim = dim)) - -function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3]) - test_fd = test_FiniteDiff(kernelfunction, args, dims) - if !test_fd.anynonpass - for AD in ADs - test_AD(AD, kernelfunction, args, dims) - end - end -end - -function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) - # Init arguments : - k = if args === nothing - kernelfunction() - else - kernelfunction(args) - end - rng = MersenneTwister(42) - @testset "FiniteDifferences" begin - if k isa SimpleKernel - for d in log.([eps(), rand(rng)]) - @test_nowarn gradient(:FiniteDiff, [d]) do x - kappa(k, exp(first(x))) - end - end - end - ## Testing Kernel Functions - x = rand(rng, dims[1]) - y = rand(rng, dims[1]) - @test_nowarn gradient(:FiniteDiff, x) do x - k(x, y) - end - if !(args === nothing) - @test_nowarn gradient(:FiniteDiff, args) do p - kernelfunction(p)(x, y) - end - end - ## Testing Kernel Matrices - A = rand(rng, dims...) - B = rand(rng, dims...) - for dim in 1:2 - @test_nowarn gradient(:FiniteDiff, A) do a - testfunction(k, a, dim) - end - @test_nowarn gradient(:FiniteDiff , A) do a - testfunction(k, a, B, dim) - end - @test_nowarn gradient(:FiniteDiff, B) do b - testfunction(k, A, b, dim) - end - if !(args === nothing) - @test_nowarn gradient(:FiniteDiff, args) do p - testfunction(kernelfunction(p), A, B, dim) - end - end - - @test_nowarn gradient(:FiniteDiff, A) do a - testdiagfunction(k, a, dim) - end - @test_nowarn gradient(:FiniteDiff , A) do a - testdiagfunction(k, a, B, dim) - end - @test_nowarn gradient(:FiniteDiff, B) do b - testdiagfunction(k, A, b, dim) - end - if !(args === nothing) - @test_nowarn gradient(:FiniteDiff, args) do p - testdiagfunction(kernelfunction(p), A, B, dim) - end - end - end - end -end - -function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) - @testset "$(AD)" begin - # Test kappa function - k = if args === nothing - kernelfunction() - else - kernelfunction(args) - end - rng = MersenneTwister(42) - if k isa SimpleKernel - for d in log.([eps(), rand(rng)]) - compare_gradient(AD, [d]) do x - kappa(k, exp(x[1])) - end - end - end - # Testing kernel evaluations - x = rand(rng, dims[1]) - y = rand(rng, dims[1]) - compare_gradient(AD, x) do x - k(x, y) - end - compare_gradient(AD, y) do y - k(x, y) - end - if !(args === nothing) - compare_gradient(AD, args) do p - kernelfunction(p)(x,y) - end - end - # Testing kernel matrices - A = rand(rng, dims...) - B = rand(rng, dims...) - for dim in 1:2 - compare_gradient(AD, A) do a - testfunction(k, a, dim) - end - compare_gradient(AD, A) do a - testfunction(k, a, B, dim) - end - compare_gradient(AD, B) do b - testfunction(k, A, b, dim) - end - if !(args === nothing) - compare_gradient(AD, args) do p - testfunction(kernelfunction(p), A, dim) - end - end - - compare_gradient(AD, A) do a - testdiagfunction(k, a, dim) - end - compare_gradient(AD, A) do a - testdiagfunction(k, a, B, dim) - end - compare_gradient(AD, B) do b - testdiagfunction(k, A, b, dim) - end - if !(args === nothing) - compare_gradient(AD, args) do p - testdiagfunction(kernelfunction(p), A, dim) - end - end - end - end -end From cae225f0d9d17eb2def40be93639d045233b8e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 14 Dec 2020 18:29:57 +0100 Subject: [PATCH 06/48] Apply suggestions from code review Co-authored-by: David Widmann --- test/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index 19ac4eeb4..8096e7c29 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -117,7 +117,7 @@ function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) @test_nowarn gradient(:FiniteDiff, B) do b testdiagfunction(k, A, b, dim) end - if !(args === nothing) + if args !== nothing @test_nowarn gradient(:FiniteDiff, args) do p testdiagfunction(kernelfunction(p), A, B, dim) end @@ -184,7 +184,7 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) compare_gradient(AD, B) do b testdiagfunction(k, A, b, dim) end - if !(args === nothing) + if args !== nothing compare_gradient(AD, args) do p testdiagfunction(kernelfunction(p), A, dim) end From 3f16f076532bd2aa33fd132637b27494c6b80b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 15 Dec 2020 14:59:10 +0100 Subject: [PATCH 07/48] Added colwise and fixed kerneldiagmatrix --- src/distances/pairwise.jl | 12 ++++++++++++ src/matrix/kernelmatrix.jl | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 188e03299..1c89a6ae7 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -39,3 +39,15 @@ function pairwise!( ) return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) end + + +# Also defines the colwise method for abstractvectors + + +function colwise(::PreMetric, x::AbstractVector) + zeros(length(x)) +end + +function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) + broadcast(d, x, y) +end \ No newline at end of file diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index 5970a5517..8b0fc059f 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -68,7 +68,7 @@ end function kerneldiagmatrix!( K::AbstractVector, κ::Kernel, x::AbstractVector, y::AbstractVector, ) - return map!(κ, x, y) + return map!(κ, K, x, y) end kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x) @@ -104,11 +104,11 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) end function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector) - return map(x -> κ(x, x), x) + return kerneldiagmatrix(κ, x, x) end function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) - return map(d -> kappa(κ, d), map(metric(κ), x, y)) + return map(d -> kappa(κ, d), colwise(metric(κ), x, y)) end From 8c0d0a20ab8d5452a28f8f8941fc7422d44932a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 16 Dec 2020 14:23:25 +0100 Subject: [PATCH 08/48] Added colwise for RowVecs and ColVecs --- src/distances/pairwise.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 1c89a6ae7..55de5d2d6 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -50,4 +50,12 @@ end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) broadcast(d, x, y) +end + +function colwise(d::PreMetric, x::RowVecs, y::RowVecs) + Distances.colwise(d, x.X', y.X') +end + +function colwise(d::PreMetric, x::ColVecs, y::ColVecs) + Distances.colwise(d, x.X, y.X) end \ No newline at end of file From 13a10fdd14492e22ad1572126a70048ffc96e192 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 21 Dec 2020 16:47:21 +0100 Subject: [PATCH 09/48] Removed definition relying on Distances.colwise! --- src/distances/pairwise.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 55de5d2d6..1c89a6ae7 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -50,12 +50,4 @@ end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) broadcast(d, x, y) -end - -function colwise(d::PreMetric, x::RowVecs, y::RowVecs) - Distances.colwise(d, x.X', y.X') -end - -function colwise(d::PreMetric, x::ColVecs, y::ColVecs) - Distances.colwise(d, x.X, y.X) end \ No newline at end of file From 5ca94e78b1ef458616cef146f4d24d6b82931dd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 16 Mar 2021 14:56:58 +0100 Subject: [PATCH 10/48] Readapt to kernelmatrix_diag --- src/kernels/transformedkernel.jl | 8 ++++---- src/matrix/kernelmatrix.jl | 2 +- test/test_utils.jl | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 100443c49..511d60355 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -80,8 +80,8 @@ function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::Abstrac return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x)) end -function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector) - return kerneldiagmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) end function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) @@ -98,8 +98,8 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector) return kernelmatrix_diag(κ.kernel, _map(κ.transform, x)) end -function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) - return kerneldiagmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) end function kernelmatrix(κ::TransformedKernel, x::AbstractVector) diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index 3bddd616d..9f2c43d2f 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -101,7 +101,7 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector) - return kerneldiagmatrix(κ, x, x) + return κ.(x, x) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) diff --git a/test/test_utils.jl b/test/test_utils.jl index 4f81d30cb..b88d61aff 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -50,8 +50,8 @@ end testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim=dim)) testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim=dim)) -testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim=dim)) -testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim=dim)) +testdiagfunction(k, A, dim) = sum(kernelmatrix_diag(k, A, obsdim=dim)) +testdiagfunction(k, A, B, dim) = sum(kernelmatrix_diag(k, A, B, obsdim=dim)) function test_ADs( kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] From 2c60abd93fafb4f2203d844626238c50042b92e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 16 Mar 2021 15:48:51 +0100 Subject: [PATCH 11/48] Fixes for Zygote --- src/distances/pairwise.jl | 24 ++++++++++++++++++++---- src/matrix/kernelmatrix.jl | 4 ++-- src/zygote_adjoints.jl | 1 + 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 256bbbc5f..c58e47458 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -34,10 +34,26 @@ end # Also defines the colwise method for abstractvectors -function colwise(::PreMetric, x::AbstractVector) - zeros(length(x)) +function Distances.colwise(d::PreMetric, x::ColVecs) + Distances.colwise(d, x.X, x.X) end -function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) - broadcast(d, x, y) +function Distances.colwise(d::PreMetric, x::RowVecs) + Distances.colwise(d, x.X', x.X') +end + +function Distances.colwise(d::PreMetric, x::AbstractVector) + d.(x, x) +end + +function Distances.colwise(d::PreMetric, x::ColVecs, y::ColVecs) + Distances.colwise(d, x.X, y.X) +end + +function Distances.colwise(d::PreMetric, x::RowVecs, y::RowVecs) + Distances.colwise(d, x.X', y.X') +end + +function Distances.colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) + d.(x, y) end \ No newline at end of file diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index 9f2c43d2f..4ebf7013a 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -101,11 +101,11 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector) - return κ.(x, x) + return map(d -> kappa(κ, d), Distances.colwise(metric(κ), x)) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) - return map(d -> kappa(κ, d), colwise(metric(κ), x, y)) + return map(d -> kappa(κ, d), Distances.colwise(metric(κ), x, y)) end diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 8a9696b5d..53dc377ff 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -74,6 +74,7 @@ end RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) + @show Δ return throw(error("In slow method")) end return RowVecs(X), RowVecs_pullback From 92142114c0c28d169b6368e2cf3b5a2b6879dfec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 16 Mar 2021 16:27:10 +0100 Subject: [PATCH 12/48] Remove type piracy --- src/distances/pairwise.jl | 12 ++++++------ src/matrix/kernelmatrix.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index c58e47458..48b1a49fb 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -34,26 +34,26 @@ end # Also defines the colwise method for abstractvectors -function Distances.colwise(d::PreMetric, x::ColVecs) +function colwise(d::PreMetric, x::ColVecs) Distances.colwise(d, x.X, x.X) end -function Distances.colwise(d::PreMetric, x::RowVecs) +function colwise(d::PreMetric, x::RowVecs) Distances.colwise(d, x.X', x.X') end -function Distances.colwise(d::PreMetric, x::AbstractVector) +function colwise(d::PreMetric, x::AbstractVector) d.(x, x) end -function Distances.colwise(d::PreMetric, x::ColVecs, y::ColVecs) +function colwise(d::PreMetric, x::ColVecs, y::ColVecs) Distances.colwise(d, x.X, y.X) end -function Distances.colwise(d::PreMetric, x::RowVecs, y::RowVecs) +function colwise(d::PreMetric, x::RowVecs, y::RowVecs) Distances.colwise(d, x.X', y.X') end -function Distances.colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) +function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) d.(x, y) end \ No newline at end of file diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index 4ebf7013a..a549e4151 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -101,11 +101,11 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector) - return map(d -> kappa(κ, d), Distances.colwise(metric(κ), x)) + return map(d -> kappa(κ, d), colwise(metric(κ), x)) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) - return map(d -> kappa(κ, d), Distances.colwise(metric(κ), x, y)) + return map(d -> kappa(κ, d), colwise(metric(κ), x, y)) end From 87edbc8c664a3dd67822e8aaf32fbac0349e8519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 17 Mar 2021 15:56:40 +0100 Subject: [PATCH 13/48] Adding some adjoints (not everything fixed yet) --- src/zygote_adjoints.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 53dc377ff..3be0fdf34 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -23,6 +23,11 @@ end end end +@adjoint function Distances.colwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix) + return Distances.colwise(d, X, Y), function (Δ::AbstractVector) + return (nothing, nothing, nothing) + end +end ## Adjoints DotProduct @adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) return dot(x, y), Δ -> begin @@ -50,6 +55,12 @@ end end end +@adjoint function Distances.colwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix) + return Distances.colwise(d, X, Y), function (Δ::AbstractVector) + return (nothing, Δ .* Y, Δ .* X) + end +end + ## Adjoints Sinus @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) d = (x - y) From f65556b791df8774d05ec29c302d890641e686d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 17 Mar 2021 16:03:18 +0100 Subject: [PATCH 14/48] Fixed adjoint for polynomials --- src/zygote_adjoints.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 3be0fdf34..d668dd3e4 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -57,7 +57,7 @@ end @adjoint function Distances.colwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix) return Distances.colwise(d, X, Y), function (Δ::AbstractVector) - return (nothing, Δ .* Y, Δ .* X) + return (nothing, Δ' .* Y, Δ' .* X) end end From 48e2dcbe7c3ce1d021dba2dc495504fb0812b7ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 17 Mar 2021 17:58:13 +0100 Subject: [PATCH 15/48] Add ChainRulesCore for defining rrule --- Project.toml | 2 + src/matrix/kernelmatrix.jl | 14 +++--- src/zygote_adjoints.jl | 91 +++++++++++++++++++++++--------------- 3 files changed, 63 insertions(+), 44 deletions(-) diff --git a/Project.toml b/Project.toml index 77c557957..0e39a1f50 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" version = "0.8.24" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -17,6 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ChainRulesCore = "0.9" Compat = "3.7" Distances = "0.9.1, 0.10" Functors = "0.1" diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index a549e4151..f524fc48e 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -80,7 +80,7 @@ kernelmatrix_diag(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector) validate_inplace_dims(K, x) pairwise!(K, metric(κ), x) - return map!(d -> kappa(κ, d), K, K) + return map!(Base.Fix1(kappa, κ), K, K) end function kernelmatrix!( @@ -88,28 +88,26 @@ function kernelmatrix!( ) validate_inplace_dims(K, x, y) pairwise!(K, metric(κ), x, y) - return map!(d -> kappa(κ, d), K, K) + return map!(Base.Fix1(kappa, κ), K, K) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector) - return map(d -> kappa(κ, d), pairwise(metric(κ), x)) + return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x)) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) validate_inputs(x, y) - return map(d -> kappa(κ, d), pairwise(metric(κ), x, y)) + return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x, y)) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector) - return map(d -> kappa(κ, d), colwise(metric(κ), x)) + return map(Base.Fix1(kappa, κ), colwise(metric(κ), x)) end function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) - return map(d -> kappa(κ, d), colwise(metric(κ), x, y)) + return map(Base.Fix1(kappa, κ), colwise(metric(κ), x, y)) end - - # # Wrapper methods for AbstractMatrix inputs to maintain obsdim interface. # diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index d668dd3e4..c07cea3f7 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -1,64 +1,82 @@ ## Adjoints Delta -@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) - return evaluate(s, x, y), Δ -> begin - (nothing, nothing, nothing) +function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) + d = evaluate(s, x, y) + function evaluate_pullback(::Any) + return NO_FIELDS, Zero(), Zero() end + return d, evaluate_pullback end -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims=dims) - if dims == 1 - return D, Δ -> (nothing, nothing, nothing) - else - return D, Δ -> (nothing, nothing, nothing) +function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + P = Distances.pairwise(d, X, Y; dims=dims) + function pairwise_pullback(::Any) + return NO_FIELDS, Zero(), Zero() end + return P, pairwise_pullback end -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims=dims) - if dims == 1 - return D, Δ -> (nothing, nothing) - else - return D, Δ -> (nothing, nothing) +function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) + P = Distances.pairwise(d, X; dims=dims) + function pairwise_pullback(::Any) + return NO_FIELDS, Zero() end + return P, pairwise_pullback end -@adjoint function Distances.colwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix) - return Distances.colwise(d, X, Y), function (Δ::AbstractVector) - return (nothing, nothing, nothing) +function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) + C = Distances.colwise(d, X, Y) + function colwise_pullback(::AbstractVector) + return NO_FIELDS, Zero(), Zero() end + return C, colwise_pullback end ## Adjoints DotProduct -@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) - return dot(x, y), Δ -> begin - (nothing, Δ .* y, Δ .* x) +function rrule(::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) + d = dot(x, y) + function evaluate_pullback(Δ) + return NO_FIELDS, Δ .* y, Δ .* x end + return d, evaluate_pullback end -@adjoint function Distances.pairwise( +function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2 ) - D = Distances.pairwise(d, X, Y; dims=dims) + P = Distances.pairwise(d, X, Y; dims=dims) if dims == 1 - return D, Δ -> (nothing, Δ * Y, (X' * Δ)') + function pairwise_pullback(Δ) + return NO_FIELDS, Δ * Y, Δ' * X + end + return P, pairwise_pullback else - return D, Δ -> (nothing, (Δ * Y')', X * Δ) + function pairwise_pullback(Δ) + return NO_FIELDS, Y * Δ', X * Δ + end + return P, pairwise_pullback end end -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims=dims) +function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) + P = Distances.pairwise(d, X; dims=dims) if dims == 1 - return D, Δ -> (nothing, 2 * Δ * X) + function pairwise_pullback(Δ) + NO_FIELDS, 2 * Δ * X + end + return P, pairwise_pullback else - return D, Δ -> (nothing, 2 * X * Δ) + function pairwise_pullback(Δ) + NO_FIELDS, 2 * X * Δ + end + return P, pairwise_pullback end end -@adjoint function Distances.colwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix) - return Distances.colwise(d, X, Y), function (Δ::AbstractVector) +function rrule(::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix) + C = Distances.colwise(d, X, Y) + function colwise_pullback(Δ::AbstractVector) return (nothing, Δ' .* Y, Δ' .* X) end + return C, colwise_pullback end ## Adjoints Sinus @@ -72,7 +90,9 @@ end end end -@adjoint function ColVecs(X::AbstractMatrix) +## Adjoints for matrix wrappers + +function rrule(::typeof(ColVecs), X::AbstractMatrix) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) @@ -81,22 +101,21 @@ end return ColVecs(X), ColVecs_pullback end -@adjoint function RowVecs(X::AbstractMatrix) +function rrule(::typeof(RowVecs), X::AbstractMatrix) RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) - @show Δ return throw(error("In slow method")) end return RowVecs(X), RowVecs_pullback end @adjoint function Base.map(t::Transform, X::ColVecs) - return pullback(_map, t, X) + return ZygoteRules.pullback(_map, t, X) end @adjoint function Base.map(t::Transform, X::RowVecs) - return pullback(_map, t, X) + return ZygoteRules.pullback(_map, t, X) end @adjoint function (dist::Distances.SqMahalanobis)(a, b) From 6cc803d8677f8e08d7688abc88dd03a0c8c07f50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 17 Mar 2021 17:59:18 +0100 Subject: [PATCH 16/48] Replace broadcast by map --- src/distances/pairwise.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 48b1a49fb..8cb13ff20 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -43,7 +43,7 @@ function colwise(d::PreMetric, x::RowVecs) end function colwise(d::PreMetric, x::AbstractVector) - d.(x, x) + map(d, x, x) end function colwise(d::PreMetric, x::ColVecs, y::ColVecs) @@ -55,5 +55,5 @@ function colwise(d::PreMetric, x::RowVecs, y::RowVecs) end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) - d.(x, y) + map(d, x, y) end \ No newline at end of file From 0e30941594ea968f784058cd1daa1aecda09e592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 17 Mar 2021 18:27:12 +0100 Subject: [PATCH 17/48] Missing return for style --- src/distances/pairwise.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 8cb13ff20..312db6482 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -30,30 +30,28 @@ function pairwise!( return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) end - # Also defines the colwise method for abstractvectors - function colwise(d::PreMetric, x::ColVecs) - Distances.colwise(d, x.X, x.X) + return Distances.colwise(d, x.X, x.X) end function colwise(d::PreMetric, x::RowVecs) - Distances.colwise(d, x.X', x.X') + return Distances.colwise(d, x.X', x.X') end function colwise(d::PreMetric, x::AbstractVector) - map(d, x, x) + return map(d, x, x) end function colwise(d::PreMetric, x::ColVecs, y::ColVecs) - Distances.colwise(d, x.X, y.X) + return Distances.colwise(d, x.X, y.X) end function colwise(d::PreMetric, x::RowVecs, y::RowVecs) - Distances.colwise(d, x.X', y.X') + return Distances.colwise(d, x.X', y.X') end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) - map(d, x, y) + return map(d, x, y) end \ No newline at end of file From 61869b1a28dff2eb445d2241793aa0c00f49e889 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 14:35:33 +0100 Subject: [PATCH 18/48] Fixing ZygoteRules --- src/zygote_adjoints.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index c07cea3f7..6ebda3da6 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -44,30 +44,30 @@ function rrule(::typeof(Distances.pairwise), ) P = Distances.pairwise(d, X, Y; dims=dims) if dims == 1 - function pairwise_pullback(Δ) + function pairwise_pullback_cols(Δ) return NO_FIELDS, Δ * Y, Δ' * X end - return P, pairwise_pullback + return P, pairwise_pullback_cols else - function pairwise_pullback(Δ) + function pairwise_pullback_rows(Δ) return NO_FIELDS, Y * Δ', X * Δ end - return P, pairwise_pullback + return P, pairwise_pullback_rows end end function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) P = Distances.pairwise(d, X; dims=dims) if dims == 1 - function pairwise_pullback(Δ) + function pairwise_pullback_cols(Δ) NO_FIELDS, 2 * Δ * X end - return P, pairwise_pullback + return P, pairwise_pullback_cols else - function pairwise_pullback(Δ) + function pairwise_pullback_rows(Δ) NO_FIELDS, 2 * X * Δ end - return P, pairwise_pullback + return P, pairwise_pullback_rows end end @@ -80,19 +80,20 @@ function rrule(::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y: end ## Adjoints Sinus -@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) +function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) d = (x - y) sind = sinpi.(d) val = sum(abs2, sind ./ s.r) gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) - return val, Δ -> begin - ((r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx) + function evaluate_pullback(Δ) + return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx end + return val, evaluate_pullback end ## Adjoints for matrix wrappers -function rrule(::typeof(ColVecs), X::AbstractMatrix) +function rrule(::ColVecs, X::AbstractMatrix) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) @@ -101,7 +102,7 @@ function rrule(::typeof(ColVecs), X::AbstractMatrix) return ColVecs(X), ColVecs_pullback end -function rrule(::typeof(RowVecs), X::AbstractMatrix) +function rrule(::RowVecs, X::AbstractMatrix) RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) From 06bd4f08c57c6ff1a7c92b2af361f8d6f9992f33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 17:19:42 +0100 Subject: [PATCH 19/48] Renamed zygote_adjoints to chainrules --- src/KernelFunctions.jl | 5 +-- src/{zygote_adjoints.jl => chainrules.jl} | 40 ++++++++++++----------- 2 files changed, 24 insertions(+), 21 deletions(-) rename src/{zygote_adjoints.jl => chainrules.jl} (83%) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 7a6ec8a6c..7a3079ae7 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,11 +55,12 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat +using ChainRulesCore using Requires using Distances, LinearAlgebra using Functors using SpecialFunctions: loggamma, besselk, polygamma -using ZygoteRules: @adjoint, pullback +# using ZygoteRules: @adjoint, pullback, ZygoteRules using StatsFuns: logtwo using StatsBase using TensorCore @@ -112,7 +113,7 @@ include(joinpath("mokernels", "moinput.jl")) include(joinpath("mokernels", "independent.jl")) include(joinpath("mokernels", "slfm.jl")) -include("zygote_adjoints.jl") +include("chainrules.jl") include("test_utils.jl") diff --git a/src/zygote_adjoints.jl b/src/chainrules.jl similarity index 83% rename from src/zygote_adjoints.jl rename to src/chainrules.jl index 6ebda3da6..a5023aa9b 100644 --- a/src/zygote_adjoints.jl +++ b/src/chainrules.jl @@ -1,4 +1,5 @@ -## Adjoints Delta +## Reverse Rules Delta + function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) d = evaluate(s, x, y) function evaluate_pullback(::Any) @@ -30,7 +31,8 @@ function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::Abst end return C, colwise_pullback end -## Adjoints DotProduct + +## Reverse Rules DotProduct function rrule(::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) d = dot(x, y) function evaluate_pullback(Δ) @@ -79,7 +81,7 @@ function rrule(::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y: return C, colwise_pullback end -## Adjoints Sinus +## Reverse Rules Sinus function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) d = (x - y) sind = sinpi.(d) @@ -91,7 +93,7 @@ function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::Abs return val, evaluate_pullback end -## Adjoints for matrix wrappers +## Reverse Rules for matrix wrappers function rrule(::ColVecs, X::AbstractMatrix) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) @@ -111,20 +113,20 @@ function rrule(::RowVecs, X::AbstractMatrix) return RowVecs(X), RowVecs_pullback end -@adjoint function Base.map(t::Transform, X::ColVecs) - return ZygoteRules.pullback(_map, t, X) -end +# function rrule(::typeof(Base.map), t::Transform, X::ColVecs) +# return pullback(_map, t, X) +# end -@adjoint function Base.map(t::Transform, X::RowVecs) - return ZygoteRules.pullback(_map, t, X) -end +# function rrule(::typeof(Base.map), t::Transform, X::RowVecs) +# return pullback(_map, t, X) +# end -@adjoint function (dist::Distances.SqMahalanobis)(a, b) - function SqMahalanobis_pullback(Δ::Real) - B_Bᵀ = dist.qmat + transpose(dist.qmat) - a_b = a - b - δa = (B_Bᵀ * a_b) * Δ - return (qmat=(a_b * a_b') * Δ,), δa, -δa - end - return evaluate(dist, a, b), SqMahalanobis_pullback -end +# @adjoint function (dist::Distances.SqMahalanobis)(a, b) +# function SqMahalanobis_pullback(Δ::Real) +# B_Bᵀ = dist.qmat + transpose(dist.qmat) +# a_b = a - b +# δa = (B_Bᵀ * a_b) * Δ +# return (qmat=(a_b * a_b') * Δ,), δa, -δa +# end +# return evaluate(dist, a, b), SqMahalanobis_pullback +# end From 8e1e516b5a248e46c0a86491fecc29c6a9814a3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 17:30:43 +0100 Subject: [PATCH 20/48] Apply formatting suggestions Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 30 ++++++++++++++++++++---------- src/kernels/transformedkernel.jl | 4 +++- test/test_utils.jl | 10 +++++----- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index a5023aa9b..139d64dc0 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -8,7 +8,9 @@ function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::Abs return d, evaluate_pullback end -function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) +function rrule( + ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 +) P = Distances.pairwise(d, X, Y; dims=dims) function pairwise_pullback(::Any) return NO_FIELDS, Zero(), Zero() @@ -33,7 +35,9 @@ function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::Abst end ## Reverse Rules DotProduct -function rrule(::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) +function rrule( + ::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector +) d = dot(x, y) function evaluate_pullback(Δ) return NO_FIELDS, Δ .* y, Δ .* x @@ -41,8 +45,12 @@ function rrule(::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y return d, evaluate_pullback end -function rrule(::typeof(Distances.pairwise), - d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2 +function rrule( + ::typeof(Distances.pairwise), + d::DotProduct, + X::AbstractMatrix, + Y::AbstractMatrix; + dims=2, ) P = Distances.pairwise(d, X, Y; dims=dims) if dims == 1 @@ -62,21 +70,23 @@ function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; d P = Distances.pairwise(d, X; dims=dims) if dims == 1 function pairwise_pullback_cols(Δ) - NO_FIELDS, 2 * Δ * X + return NO_FIELDS, 2 * Δ * X end return P, pairwise_pullback_cols else function pairwise_pullback_rows(Δ) - NO_FIELDS, 2 * X * Δ + return NO_FIELDS, 2 * X * Δ end return P, pairwise_pullback_rows end end -function rrule(::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix) - C = Distances.colwise(d, X, Y) +function rrule( + ::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix +) + C = Distances.colwise(d, X, Y) function colwise_pullback(Δ::AbstractVector) - return (nothing, Δ' .* Y, Δ' .* X) + return (nothing, Δ' .* Y, Δ' .* X) end return C, colwise_pullback end @@ -88,7 +98,7 @@ function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::Abs val = sum(abs2, sind ./ s.r) gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) function evaluate_pullback(Δ) - return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx + return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx end return val, evaluate_pullback end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 511d60355..6cf693dca 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -80,7 +80,9 @@ function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::Abstrac return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x)) end -function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector) +function kernelmatrix_diag!( + K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector +) return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) end diff --git a/test/test_utils.jl b/test/test_utils.jl index b88d61aff..caa04b8d3 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -48,10 +48,10 @@ function compare_gradient(f, AD::Symbol, args) @test grad_AD ≈ grad_FD atol = 1e-8 rtol = 1e-5 end -testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim=dim)) -testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim=dim)) -testdiagfunction(k, A, dim) = sum(kernelmatrix_diag(k, A, obsdim=dim)) -testdiagfunction(k, A, B, dim) = sum(kernelmatrix_diag(k, A, B, obsdim=dim)) +testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B; obsdim=dim)) +testfunction(k, A, dim) = sum(kernelmatrix(k, A; obsdim=dim)) +testdiagfunction(k, A, dim) = sum(kernelmatrix_diag(k, A; obsdim=dim)) +testdiagfunction(k, A, B, dim) = sum(kernelmatrix_diag(k, A, B; obsdim=dim)) function test_ADs( kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] @@ -113,7 +113,7 @@ function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3]) @test_nowarn gradient(:FiniteDiff, A) do a testdiagfunction(k, a, dim) end - @test_nowarn gradient(:FiniteDiff , A) do a + @test_nowarn gradient(:FiniteDiff, A) do a testdiagfunction(k, a, B, dim) end @test_nowarn gradient(:FiniteDiff, B) do b From aaa16deaa945ed668e130e1adf3bdac1881017fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 18:05:22 +0100 Subject: [PATCH 21/48] Added forward rule for Euclidean distance --- src/KernelFunctions.jl | 2 +- src/chainrules.jl | 43 +++++++++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 7a3079ae7..e9d568786 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,7 +55,7 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat -using ChainRulesCore +using ChainRulesCore: ChainRulesCore, Zero, NO_FIELDS using Requires using Distances, LinearAlgebra using Functors diff --git a/src/chainrules.jl b/src/chainrules.jl index 139d64dc0..17e7d5169 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,6 +1,19 @@ +## Forward Rules + +function ChainRulesCore.frule((Δself, Δx, Δy), d::Euclidean, x::AbstractVector, y::AbstractVector) + Δ = x - y + D = norm(Δ, 1) + if iszero(D) + return D, Zero() + else + Δ ./= D + return D, dot(Δ, Δx) - dot(Δ, Δy) + end +end + ## Reverse Rules Delta -function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) +function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) d = evaluate(s, x, y) function evaluate_pullback(::Any) return NO_FIELDS, Zero(), Zero() @@ -8,7 +21,7 @@ function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::Abs return d, evaluate_pullback end -function rrule( +function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X, Y; dims=dims) @@ -18,7 +31,7 @@ function rrule( return P, pairwise_pullback end -function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) P = Distances.pairwise(d, X; dims=dims) function pairwise_pullback(::Any) return NO_FIELDS, Zero() @@ -26,7 +39,7 @@ function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 return P, pairwise_pullback end -function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) +function ChainRulesCore.rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) C = Distances.colwise(d, X, Y) function colwise_pullback(::AbstractVector) return NO_FIELDS, Zero(), Zero() @@ -35,7 +48,7 @@ function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::Abst end ## Reverse Rules DotProduct -function rrule( +function ChainRulesCore.rrule( ::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector ) d = dot(x, y) @@ -45,7 +58,7 @@ function rrule( return d, evaluate_pullback end -function rrule( +function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, @@ -66,7 +79,7 @@ function rrule( end end -function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) P = Distances.pairwise(d, X; dims=dims) if dims == 1 function pairwise_pullback_cols(Δ) @@ -81,18 +94,18 @@ function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; d end end -function rrule( +function ChainRulesCore.rrule( ::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix ) C = Distances.colwise(d, X, Y) function colwise_pullback(Δ::AbstractVector) - return (nothing, Δ' .* Y, Δ' .* X) + return (NO_FIELDS, Δ' .* Y, Δ' .* X) end return C, colwise_pullback end ## Reverse Rules Sinus -function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) +function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) d = (x - y) sind = sinpi.(d) val = sum(abs2, sind ./ s.r) @@ -105,26 +118,26 @@ end ## Reverse Rules for matrix wrappers -function rrule(::ColVecs, X::AbstractMatrix) +function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) - function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) + function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return throw(error("In slow method")) end return ColVecs(X), ColVecs_pullback end -function rrule(::RowVecs, X::AbstractMatrix) +function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) - function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) + function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return throw(error("In slow method")) end return RowVecs(X), RowVecs_pullback end # function rrule(::typeof(Base.map), t::Transform, X::ColVecs) -# return pullback(_map, t, X) + # return pullback(_map, t, X) # end # function rrule(::typeof(Base.map), t::Transform, X::RowVecs) From 52b1ae5389222b4717e2f3672e34c4aeeef97934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 19:03:02 +0100 Subject: [PATCH 22/48] Corrected rules for Row/ColVecs constructors --- src/KernelFunctions.jl | 4 ++-- src/chainrules.jl | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index e9d568786..3a0e37022 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,13 +55,13 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat -using ChainRulesCore: ChainRulesCore, Zero, NO_FIELDS +using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS using Requires using Distances, LinearAlgebra using Functors using SpecialFunctions: loggamma, besselk, polygamma # using ZygoteRules: @adjoint, pullback, ZygoteRules -using StatsFuns: logtwo +using StatsFuns: logtwo, twoπ using StatsBase using TensorCore diff --git a/src/chainrules.jl b/src/chainrules.jl index 17e7d5169..0fccc2e20 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,8 +1,8 @@ ## Forward Rules -function ChainRulesCore.frule((Δself, Δx, Δy), d::Euclidean, x::AbstractVector, y::AbstractVector) +function ChainRulesCore.frule((_, Δx, Δy), d::Distances.Euclidean, x, y) Δ = x - y - D = norm(Δ, 1) + D = sqrt(sum(abs2, Δ)) if iszero(D) return D, Zero() else @@ -105,11 +105,12 @@ function ChainRulesCore.rrule( end ## Reverse Rules Sinus + function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) - d = (x - y) + d = x - y sind = sinpi.(d) val = sum(abs2, sind ./ s.r) - gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) + gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) function evaluate_pullback(Δ) return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx end @@ -119,6 +120,7 @@ end ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) + ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) @@ -128,6 +130,7 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) end function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) + RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) From 4067a42062147c2fd76556769d73f9b606419bf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 19:13:25 +0100 Subject: [PATCH 23/48] Added ZygoteRules back for the "map hack" --- src/KernelFunctions.jl | 2 +- src/chainrules.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 3a0e37022..afab41ad1 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -60,7 +60,7 @@ using Requires using Distances, LinearAlgebra using Functors using SpecialFunctions: loggamma, besselk, polygamma -# using ZygoteRules: @adjoint, pullback, ZygoteRules +using ZygoteRules: ZygoteRules using StatsFuns: logtwo, twoπ using StatsBase using TensorCore diff --git a/src/chainrules.jl b/src/chainrules.jl index 0fccc2e20..78a43d1b2 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -139,13 +139,13 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) return RowVecs(X), RowVecs_pullback end -# function rrule(::typeof(Base.map), t::Transform, X::ColVecs) - # return pullback(_map, t, X) -# end +ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) + return ZygoteRules.pullback(_map, t, X) +end -# function rrule(::typeof(Base.map), t::Transform, X::RowVecs) -# return pullback(_map, t, X) -# end +ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) + return ZygoteRules.pullback(_map, t, X) +end # @adjoint function (dist::Distances.SqMahalanobis)(a, b) # function SqMahalanobis_pullback(Δ::Real) From 641ebeeee45c9c6cb2d0a6045f4d4ab87ab8a792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 19:27:07 +0100 Subject: [PATCH 24/48] Corrected the rrules --- src/chainrules.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 78a43d1b2..6c8d6d387 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -16,7 +16,7 @@ end function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) d = evaluate(s, x, y) function evaluate_pullback(::Any) - return NO_FIELDS, Zero(), Zero() + return NO_FIELDS, NO_FIELDS, Zero(), Zero() end return d, evaluate_pullback end @@ -26,7 +26,7 @@ function ChainRulesCore.rrule( ) P = Distances.pairwise(d, X, Y; dims=dims) function pairwise_pullback(::Any) - return NO_FIELDS, Zero(), Zero() + return NO_FIELDS, NO_FIELDS, Zero(), Zero() end return P, pairwise_pullback end @@ -34,7 +34,7 @@ end function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) P = Distances.pairwise(d, X; dims=dims) function pairwise_pullback(::Any) - return NO_FIELDS, Zero() + return NO_FIELDS, NO_FIELDS, Zero() end return P, pairwise_pullback end @@ -42,18 +42,19 @@ end function ChainRulesCore.rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) C = Distances.colwise(d, X, Y) function colwise_pullback(::AbstractVector) - return NO_FIELDS, Zero(), Zero() + return NO_FIELDS, NO_FIELDS, Zero(), Zero() end return C, colwise_pullback end ## Reverse Rules DotProduct + function ChainRulesCore.rrule( ::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector ) d = dot(x, y) function evaluate_pullback(Δ) - return NO_FIELDS, Δ .* y, Δ .* x + return NO_FIELDS, NO_FIELDS, Δ .* y, Δ .* x end return d, evaluate_pullback end @@ -68,12 +69,12 @@ function ChainRulesCore.rrule( P = Distances.pairwise(d, X, Y; dims=dims) if dims == 1 function pairwise_pullback_cols(Δ) - return NO_FIELDS, Δ * Y, Δ' * X + return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X end return P, pairwise_pullback_cols else function pairwise_pullback_rows(Δ) - return NO_FIELDS, Y * Δ', X * Δ + return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ end return P, pairwise_pullback_rows end @@ -83,12 +84,12 @@ function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::Ab P = Distances.pairwise(d, X; dims=dims) if dims == 1 function pairwise_pullback_cols(Δ) - return NO_FIELDS, 2 * Δ * X + return NO_FIELDS, NO_FIELDS, 2 * Δ * X end return P, pairwise_pullback_cols else function pairwise_pullback_rows(Δ) - return NO_FIELDS, 2 * X * Δ + return NO_FIELDS, NO_FIELDS, 2 * X * Δ end return P, pairwise_pullback_rows end @@ -99,7 +100,7 @@ function ChainRulesCore.rrule( ) C = Distances.colwise(d, X, Y) function colwise_pullback(Δ::AbstractVector) - return (NO_FIELDS, Δ' .* Y, Δ' .* X) + return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X end return C, colwise_pullback end @@ -112,7 +113,7 @@ function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Sinus, x::Abstrac val = sum(abs2, sind ./ s.r) gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) function evaluate_pullback(Δ) - return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx + return NO_FIELDS, (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx end return val, evaluate_pullback end From 13d1e395aec732479e662d2e69c10f66eb13125e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 19:35:11 +0100 Subject: [PATCH 25/48] Type stable frule Co-authored-by: David Widmann --- src/chainrules.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 6c8d6d387..ea386ceea 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -3,12 +3,10 @@ function ChainRulesCore.frule((_, Δx, Δy), d::Distances.Euclidean, x, y) Δ = x - y D = sqrt(sum(abs2, Δ)) - if iszero(D) - return D, Zero() - else + if !iszero(D) Δ ./= D - return D, dot(Δ, Δx) - dot(Δ, Δy) end + return D, dot(Δ, Δx) - dot(Δ, Δy) end ## Reverse Rules Delta From 4675c2f23fe9c93257d359cdf1ff517c9597f79f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 10:27:05 +0100 Subject: [PATCH 26/48] Corrected tests --- test/chainrules.jl | 27 +++++++++++++++++++++ test/zygote_adjoints.jl | 53 ----------------------------------------- 2 files changed, 27 insertions(+), 53 deletions(-) create mode 100644 test/chainrules.jl delete mode 100644 test/zygote_adjoints.jl diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 000000000..4d732be36 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,27 @@ +@testset "Chain Rules" begin + rng = MersenneTwister(123456) + x = rand(rng, 5) + y = rand(rng, 5) + r = rand(rng, 5) + Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) + @assert isposdef(Q) + + compare_gradient(:Zygote, [x, y]) do xy + Euclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + SqEuclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.DotProduct()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Delta()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Sinus(r)(xy[1], xy[2]) + end + # compare_gradient(:Zygote, [Q, x, y]) do xy + # evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) + # end +end diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl deleted file mode 100644 index 6b349437b..000000000 --- a/test/zygote_adjoints.jl +++ /dev/null @@ -1,53 +0,0 @@ -@testset "zygote_adjoints" begin - rng = MersenneTwister(123456) - x = rand(rng, 5) - y = rand(rng, 5) - r = rand(rng, 5) - Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) - @assert isposdef(Q) - - gzeucl = gradient(:Zygote, [x, y]) do xy - evaluate(Euclidean(), xy[1], xy[2]) - end - gzsqeucl = gradient(:Zygote, [x, y]) do xy - evaluate(SqEuclidean(), xy[1], xy[2]) - end - gzdotprod = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) - end - gzdelta = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.Delta(), xy[1], xy[2]) - end - gzsinus = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) - end - gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy - evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) - end - - gfeucl = gradient(:FiniteDiff, [x, y]) do xy - evaluate(Euclidean(), xy[1], xy[2]) - end - gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy - evaluate(SqEuclidean(), xy[1], xy[2]) - end - gfdotprod = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) - end - gfdelta = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.Delta(), xy[1], xy[2]) - end - gfsinus = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) - end - gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy - evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) - end - - @test all(gzeucl .≈ gfeucl) - @test all(gzsqeucl .≈ gfsqeucl) - @test all(gzdotprod .≈ gfdotprod) - @test all(gzdelta .≈ gfdelta) - @test all(gzsinus .≈ gfsinus) - @test all(gzsqmaha .≈ gfsqmaha) -end From 0b97c1a3726c7698b569636b283e7ff9a2447272 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 10:27:39 +0100 Subject: [PATCH 27/48] Adapted the use of Distances.jl --- Project.toml | 2 +- src/chainrules.jl | 12 ++++++------ src/distances/delta.jl | 2 +- src/distances/dotproduct.jl | 1 - src/distances/sinus.jl | 3 +-- src/generic.jl | 2 +- 6 files changed, 10 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 0e39a1f50..99c5b60d3 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ChainRulesCore = "0.9" Compat = "3.7" -Distances = "0.9.1, 0.10" +Distances = "0.10" Functors = "0.1" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10, 1" diff --git a/src/chainrules.jl b/src/chainrules.jl index ea386ceea..9aad22189 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -11,10 +11,10 @@ end ## Reverse Rules Delta -function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) - d = evaluate(s, x, y) +function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector) + d = dist(x, y) function evaluate_pullback(::Any) - return NO_FIELDS, NO_FIELDS, Zero(), Zero() + return NO_FIELDS, Zero(), Zero() end return d, evaluate_pullback end @@ -48,11 +48,11 @@ end ## Reverse Rules DotProduct function ChainRulesCore.rrule( - ::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector + dist::DotProduct, x::AbstractVector, y::AbstractVector ) - d = dot(x, y) + d = dist(x, y) function evaluate_pullback(Δ) - return NO_FIELDS, NO_FIELDS, Δ .* y, Δ .* x + return NO_FIELDS, Δ .* y, Δ .* x end return d, evaluate_pullback end diff --git a/src/distances/delta.jl b/src/distances/delta.jl index 273804308..b656b6804 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,4 +1,4 @@ -struct Delta <: Distances.PreMetric end +struct Delta <: Distances.UnionSemiMetric end @inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index ef0f64b28..12a88522c 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,5 +1,4 @@ struct DotProduct <: Distances.PreMetric end -# struct DotProduct <: Distances.UnionSemiMetric end @inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index f91884c5f..4bcf4bdf0 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -1,5 +1,4 @@ -struct Sinus{T} <: Distances.SemiMetric - # struct Sinus{T} <: Distances.UnionSemiMetric +struct Sinus{T} <: Distances.UnionSemiMetric r::Vector{T} end diff --git a/src/generic.jl b/src/generic.jl index ef8762fef..f161ca5ff 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -6,4 +6,4 @@ Base.iterate(k::Kernel, ::Any) = nothing printshifted(io::IO, o, shift::Int) = print(io, o) # Fallback implementation of evaluate for `SimpleKernel`s. -(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y)) +(k::SimpleKernel)(x, y) = kappa(k, metric(k)(x, y)) From ad9838ee3a1275ac11858a5c7824975101c4b717 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 12:20:33 +0100 Subject: [PATCH 28/48] Added methods to make nn work --- src/basekernels/nn.jl | 26 ++++++++++++++++++++++++++ src/chainrules.jl | 19 ++++++++++--------- src/distances/delta.jl | 2 +- test/basekernels/nn.jl | 1 - 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 52b9c607c..40070075d 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -51,6 +51,19 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs) return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end +function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs) + x_2 = vec(sum(x.X .* x.X; dims=1)) + return asin.(x_2 ./ (x_2 .+ 1)) +end + +function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) + validate_inputs(x, y) + x_2 = vec(sum(x.X .* x.X; dims=1) .+ 1) + y_2 = vec(sum(y.X .* y.X; dims=1) .+ 1) + xy = vec(sum(x.X' .* y.X'; dims=2)) + return asin.(xy ./ sqrt.(x_2 .* y_2)) +end + function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) validate_inputs(x, y) X_2 = sum(x.X .* x.X; dims=2) @@ -65,4 +78,17 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs) return asin.(XX ./ sqrt.(X_2_1 * X_2_1')) end +function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs) + x_2 = vec(sum(x.X .* x.X; dims=2)) + return asin.(x_2 ./ (x_2 .+ 1)) +end + +function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) + validate_inputs(x, y) + x_2 = vec(sum(x.X .* x.X; dims=2) .+ 1) + y_2 = vec(sum(y.X .* y.X; dims=2) .+ 1) + xy = vec(sum(x.X .* y.X; dims=2)) + return asin.(xy ./ sqrt.(x_2 .* y_2)) +end + Base.show(io::IO, ::NeuralNetworkKernel) = print(io, "Neural Network Kernel") diff --git a/src/chainrules.jl b/src/chainrules.jl index 9aad22189..f4097022f 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -146,12 +146,13 @@ ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) return ZygoteRules.pullback(_map, t, X) end -# @adjoint function (dist::Distances.SqMahalanobis)(a, b) -# function SqMahalanobis_pullback(Δ::Real) -# B_Bᵀ = dist.qmat + transpose(dist.qmat) -# a_b = a - b -# δa = (B_Bᵀ * a_b) * Δ -# return (qmat=(a_b * a_b') * Δ,), δa, -δa -# end -# return evaluate(dist, a, b), SqMahalanobis_pullback -# end +function ChainRulesCore.rrule(dist::Distances.SqMahalanobis, a, b) + d = dist(a, b) + function SqMahalanobis_pullback(Δ::Real) + B_Bᵀ = dist.qmat + transpose(dist.qmat) + a_b = a - b + δa = (B_Bᵀ * a_b) * Δ + return (qmat=(a_b * a_b') * Δ,), δa, -δa + end + return d, SqMahalanobis_pullback +end diff --git a/src/distances/delta.jl b/src/distances/delta.jl index b656b6804..7ea15d73e 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -12,7 +12,7 @@ struct Delta <: Distances.UnionSemiMetric end return a == b end -Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb) +Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool @inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) @inline (dist::Delta)(a::Number, b::Number) = a == b diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index d64312864..c9dabeb69 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -8,5 +8,4 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(NeuralNetworkKernel) - @test_broken "Zygote uncompatible with BaseKernel" end From 650dc088418357ea9e85e2129d1a85a03b6c04a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 12:33:13 +0100 Subject: [PATCH 29/48] Missing kernelmatrix_diag --- src/basekernels/gabor.jl | 2 ++ src/kernels/kernelproduct.jl | 4 ++++ src/kernels/kernelsum.jl | 4 ++++ src/kernels/kerneltensorproduct.jl | 5 +++++ src/kernels/scaledkernel.jl | 4 ++++ 5 files changed, 19 insertions(+) diff --git a/src/basekernels/gabor.jl b/src/basekernels/gabor.jl index 2796901e2..9f85d3ed5 100644 --- a/src/basekernels/gabor.jl +++ b/src/basekernels/gabor.jl @@ -72,3 +72,5 @@ function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector) end kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x) + +kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector) = kernelmatrix_diag(κ.kernel, x, y) diff --git a/src/kernels/kernelproduct.jl b/src/kernels/kernelproduct.jl index ce39dde58..990b4a1bb 100644 --- a/src/kernels/kernelproduct.jl +++ b/src/kernels/kernelproduct.jl @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector) return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels) end +function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector, y::AbstractVector) + return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels) +end + function Base.show(io::IO, κ::KernelProduct) return printshifted(io, κ, 0) end diff --git a/src/kernels/kernelsum.jl b/src/kernels/kernelsum.jl index 5dd068b12..6c4c8d499 100644 --- a/src/kernels/kernelsum.jl +++ b/src/kernels/kernelsum.jl @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelSum, x::AbstractVector) return sum(kernelmatrix_diag(k, x) for k in κ.kernels) end +function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector) + return sum(kernelmatrix_diag(k, x, y) for k in κ.kernels) +end + function Base.show(io::IO, κ::KernelSum) return printshifted(io, κ, 0) end diff --git a/src/kernels/kerneltensorproduct.jl b/src/kernels/kerneltensorproduct.jl index c46e204fc..ce9c69d6c 100644 --- a/src/kernels/kerneltensorproduct.jl +++ b/src/kernels/kerneltensorproduct.jl @@ -123,6 +123,11 @@ function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector) return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x)) end +function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector) + validate_domain(k, x) + return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x), slices(y)) +end + Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0) function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct) diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 06a48c594..9cce70ac7 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -37,6 +37,10 @@ function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector) return κ.σ² .* kernelmatrix_diag(κ.kernel, x) end +function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) + return κ.σ² .* kernelmatrix_diag(κ.kernel, x, y) +end + function kernelmatrix!( K::AbstractMatrix, κ::ScaledKernel, x::AbstractVector, y::AbstractVector ) From 1703db10d4f78eab48993c999389ff8fe1dedc7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 13:20:18 +0100 Subject: [PATCH 30/48] Formatting suggestions Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/basekernels/gabor.jl | 4 +++- src/chainrules.jl | 20 +++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/basekernels/gabor.jl b/src/basekernels/gabor.jl index 9f85d3ed5..311afcf1b 100644 --- a/src/basekernels/gabor.jl +++ b/src/basekernels/gabor.jl @@ -73,4 +73,6 @@ end kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x) -kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector) = kernelmatrix_diag(κ.kernel, x, y) +function kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(κ.kernel, x, y) +end diff --git a/src/chainrules.jl b/src/chainrules.jl index f4097022f..84296ca21 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -29,7 +29,9 @@ function ChainRulesCore.rrule( return P, pairwise_pullback end -function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 +) P = Distances.pairwise(d, X; dims=dims) function pairwise_pullback(::Any) return NO_FIELDS, NO_FIELDS, Zero() @@ -37,7 +39,9 @@ function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::Abstrac return P, pairwise_pullback end -function ChainRulesCore.rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) +function ChainRulesCore.rrule( + ::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix +) C = Distances.colwise(d, X, Y) function colwise_pullback(::AbstractVector) return NO_FIELDS, NO_FIELDS, Zero(), Zero() @@ -47,9 +51,7 @@ end ## Reverse Rules DotProduct -function ChainRulesCore.rrule( - dist::DotProduct, x::AbstractVector, y::AbstractVector -) +function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector) d = dist(x, y) function evaluate_pullback(Δ) return NO_FIELDS, Δ .* y, Δ .* x @@ -78,7 +80,9 @@ function ChainRulesCore.rrule( end end -function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 +) P = Distances.pairwise(d, X; dims=dims) if dims == 1 function pairwise_pullback_cols(Δ) @@ -105,7 +109,9 @@ end ## Reverse Rules Sinus -function ChainRulesCore.rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) +function ChainRulesCore.rrule( + ::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector +) d = x - y sind = sinpi.(d) val = sum(abs2, sind ./ s.r) From e2cd167e7d6fa8334e82f81dac1194e77a053410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 13:27:57 +0100 Subject: [PATCH 31/48] Added methods for FBM --- src/basekernels/fbm.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index fad6d70f2..eb98eabbe 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -69,3 +69,14 @@ function kernelmatrix!( K .= _fbm.(_mod(x), _mod(y)', K, κ.h) return K end + +function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector) + modx = _mod(x) + modxx = colwise(SqEuclidean(), x) + return _fbm.(modx, modx, modxx, κ.h) +end + +function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector) + modxy = colwise(SqEuclidean(), x, y) + return _fbm.(_mod(x), _mod(y), modxy, κ.h) +end \ No newline at end of file From 01ffac0faf85c33058ea3d7fd1d4c09194377641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 13:34:22 +0100 Subject: [PATCH 32/48] Last fix on Delta --- src/distances/delta.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distances/delta.jl b/src/distances/delta.jl index 7ea15d73e..61137ba6a 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,4 +1,4 @@ -struct Delta <: Distances.UnionSemiMetric end +struct Delta <: Distances.UnionPreMetric end @inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) From 9bfb6eb53945d4c5f3dbf62e599b5fc8317021ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 13:46:55 +0100 Subject: [PATCH 33/48] Potential fix for Euclidean --- src/distances/delta.jl | 1 + src/distances/dotproduct.jl | 3 ++- src/distances/pairwise.jl | 11 ++++++++--- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/distances/delta.jl b/src/distances/delta.jl index 61137ba6a..979ecc197 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,3 +1,4 @@ +# Delta is not following the PreMetric rules since d(x, x) == 1 struct Delta <: Distances.UnionPreMetric end @inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index 12a88522c..1cef13ab5 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,4 +1,5 @@ -struct DotProduct <: Distances.PreMetric end +## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y +struct DotProduct <: Distances.UnionPreMetric end @inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 312db6482..24e36513e 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -32,15 +32,20 @@ end # Also defines the colwise method for abstractvectors -function colwise(d::PreMetric, x::ColVecs) +function colwise(::PreMetric, x::AbstractVector) + return zeros(length(x)) # Valid since d(x,x) == 0 by definition +end + +## The following is a hack for DotProduct and Delta to still work +function colwise(d::UnionPreMetric, x::ColVecs) return Distances.colwise(d, x.X, x.X) end -function colwise(d::PreMetric, x::RowVecs) +function colwise(d::UnionPreMetric, x::RowVecs) return Distances.colwise(d, x.X', x.X') end -function colwise(d::PreMetric, x::AbstractVector) +function colwise(d::UnionPreMetric, x::AbstractVector) return map(d, x, x) end From f3fa4bca699d278eec1d8f29d1fc17953609d2e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 14:05:53 +0100 Subject: [PATCH 34/48] Missing Distances. --- src/distances/pairwise.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 24e36513e..e8faf020c 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -37,15 +37,15 @@ function colwise(::PreMetric, x::AbstractVector) end ## The following is a hack for DotProduct and Delta to still work -function colwise(d::UnionPreMetric, x::ColVecs) +function colwise(d::Distances.UnionPreMetric, x::ColVecs) return Distances.colwise(d, x.X, x.X) end -function colwise(d::UnionPreMetric, x::RowVecs) +function colwise(d::Distances.UnionPreMetric, x::RowVecs) return Distances.colwise(d, x.X', x.X') end -function colwise(d::UnionPreMetric, x::AbstractVector) +function colwise(d::Distances.UnionPreMetric, x::AbstractVector) return map(d, x, x) end From a0c2a64d4a927052a43ec161926f15639c6ffa63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 14:11:41 +0100 Subject: [PATCH 35/48] Wrong file naming --- test/chainrules.jl | 6 +++--- test/runtests.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4d732be36..51a545ba1 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -21,7 +21,7 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end - # compare_gradient(:Zygote, [Q, x, y]) do xy - # evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) - # end + compare_gradient(:Zygote, [Q, x, y]) do xy + SqMahalanobis(xy[1])(xy[2], xy[3]) + end end diff --git a/test/runtests.jl b/test/runtests.jl index 7ad679905..a4df89988 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -146,7 +146,7 @@ include("test_utils.jl") end include("generic.jl") - include("zygote_adjoints.jl") + include("chainrules.jl") @testset "doctests" begin DocMeta.setdocmeta!( From ff5a66b08c0622fea87aee6a27db88f0705d4f02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 19:03:12 +0100 Subject: [PATCH 36/48] Correct formatting --- src/basekernels/fbm.jl | 2 +- src/distances/pairwise.jl | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index eb98eabbe..7ea88e110 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -79,4 +79,4 @@ end function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector) modxy = colwise(SqEuclidean(), x, y) return _fbm.(_mod(x), _mod(y), modxy, κ.h) -end \ No newline at end of file +end diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index e8faf020c..e7670f451 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -13,21 +13,24 @@ end pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X) function pairwise(d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise(d, reshape(x, :, 1); dims=1) + return Distances.pairwise(d, reshape(x, :, 1); dims = 1) end function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1) + return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims = 1) end function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1) + return Distances.pairwise!(out, d, reshape(x, :, 1); dims = 1) end function pairwise!( - out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real} + out::AbstractMatrix, + d::PreMetric, + x::AbstractVector{<:Real}, + y::AbstractVector{<:Real}, ) - return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) + return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims = 1) end # Also defines the colwise method for abstractvectors @@ -59,4 +62,4 @@ end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) return map(d, x, y) -end \ No newline at end of file +end From 8157b4ccd917ee171482b5ed44cf6f791f070f8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 19:11:08 +0100 Subject: [PATCH 37/48] Better error message --- src/chainrules.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 84296ca21..7a85ee31a 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -109,15 +109,13 @@ end ## Reverse Rules Sinus -function ChainRulesCore.rrule( - ::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector -) +function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y sind = sinpi.(d) val = sum(abs2, sind ./ s.r) gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) function evaluate_pullback(Δ) - return NO_FIELDS, (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx + return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx end return val, evaluate_pullback end @@ -129,7 +127,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) - return throw(error("In slow method")) + return error( + "Pullback on AbstractVector{<:AbstractVector}.\n" * + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * + "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + ) end return ColVecs(X), ColVecs_pullback end @@ -139,7 +141,11 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) - return throw(error("In slow method")) + return error( + "Pullback on AbstractVector{<:AbstractVector}.\n" * + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * + "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", + ) end return RowVecs(X), RowVecs_pullback end From e6bfdb1761a425617cc0112840ee5929cd65f9fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 23 Mar 2021 19:39:39 +0100 Subject: [PATCH 38/48] Moar formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/distances/pairwise.jl | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index e7670f451..14d36586b 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -13,24 +13,21 @@ end pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X) function pairwise(d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise(d, reshape(x, :, 1); dims = 1) + return Distances.pairwise(d, reshape(x, :, 1); dims=1) end function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims = 1) + return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1) end function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise!(out, d, reshape(x, :, 1); dims = 1) + return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1) end function pairwise!( - out::AbstractMatrix, - d::PreMetric, - x::AbstractVector{<:Real}, - y::AbstractVector{<:Real}, + out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real} ) - return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims = 1) + return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) end # Also defines the colwise method for abstractvectors From db5e7b8e131ba801397fac73198638eccb9a96b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 16:43:03 +0100 Subject: [PATCH 39/48] Applied suggestions --- src/KernelFunctions.jl | 1 + src/chainrules.jl | 42 +++++++++++++++------------------------ src/distances/pairwise.jl | 4 ++-- src/zygoterules.jl | 7 +++++++ test/Project.toml | 1 + test/runtests.jl | 2 ++ test/zygoterules.jl | 3 +++ 7 files changed, 32 insertions(+), 28 deletions(-) create mode 100644 src/zygoterules.jl create mode 100644 test/zygoterules.jl diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index afab41ad1..d270deaee 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -114,6 +114,7 @@ include(joinpath("mokernels", "independent.jl")) include(joinpath("mokernels", "slfm.jl")) include("chainrules.jl") +include("zygoterules.jl") include("test_utils.jl") diff --git a/src/chainrules.jl b/src/chainrules.jl index 7a85ee31a..d9e6a17e1 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,6 +1,7 @@ ## Forward Rules -function ChainRulesCore.frule((_, Δx, Δy), d::Distances.Euclidean, x, y) +# Note that this is type piracy as the derivative should be NaN for x == y. +function ChainRulesCore.frule((_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector) Δ = x - y D = sqrt(sum(abs2, Δ)) if !iszero(D) @@ -23,7 +24,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X, Y; dims=dims) - function pairwise_pullback(::Any) + function pairwise_pullback(::AbstractMatrix) return NO_FIELDS, NO_FIELDS, Zero(), Zero() end return P, pairwise_pullback @@ -33,7 +34,7 @@ function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X; dims=dims) - function pairwise_pullback(::Any) + function pairwise_pullback(::AbstractMatrix) return NO_FIELDS, NO_FIELDS, Zero() end return P, pairwise_pullback @@ -53,7 +54,7 @@ end function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector) d = dist(x, y) - function evaluate_pullback(Δ) + function evaluate_pullback(Δ::Any) return NO_FIELDS, Δ .* y, Δ .* x end return d, evaluate_pullback @@ -68,12 +69,12 @@ function ChainRulesCore.rrule( ) P = Distances.pairwise(d, X, Y; dims=dims) if dims == 1 - function pairwise_pullback_cols(Δ) + function pairwise_pullback_cols(Δ::AbstractMatrix) return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X end return P, pairwise_pullback_cols else - function pairwise_pullback_rows(Δ) + function pairwise_pullback_rows(Δ::AbstractMatrix) return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ end return P, pairwise_pullback_rows @@ -85,12 +86,12 @@ function ChainRulesCore.rrule( ) P = Distances.pairwise(d, X; dims=dims) if dims == 1 - function pairwise_pullback_cols(Δ) + function pairwise_pullback_cols(Δ::AbstractMatrix) return NO_FIELDS, NO_FIELDS, 2 * Δ * X end return P, pairwise_pullback_cols else - function pairwise_pullback_rows(Δ) + function pairwise_pullback_rows(Δ::AbstractMatrix) return NO_FIELDS, NO_FIELDS, 2 * X * Δ end return P, pairwise_pullback_rows @@ -112,10 +113,11 @@ end function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y sind = sinpi.(d) - val = sum(abs2, sind ./ s.r) + abs2_sind_r = abs2.(sind) ./ s.r + val = sum(abs2_sind_r) gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) - function evaluate_pullback(Δ) - return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx + function evaluate_pullback(Δ::Any) + return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx end return val, evaluate_pullback end @@ -123,9 +125,7 @@ end ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) - ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) - ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) - ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) + ColVecs_pullback(Δ::Composite{<:ColVecs}) = (NO_FIELDS, Δ.X) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * @@ -137,9 +137,7 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) end function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) - RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) - RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) - RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) + RowVecs_pullback(Δ::Composite{<:RowVecs}) = (NO_FIELDS, Δ.X) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * @@ -150,20 +148,12 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) return RowVecs(X), RowVecs_pullback end -ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) - return ZygoteRules.pullback(_map, t, X) -end - -ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) - return ZygoteRules.pullback(_map, t, X) -end - function ChainRulesCore.rrule(dist::Distances.SqMahalanobis, a, b) d = dist(a, b) function SqMahalanobis_pullback(Δ::Real) B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b - δa = (B_Bᵀ * a_b) * Δ + δa = @thunk((B_Bᵀ * a_b) * Δ) return (qmat=(a_b * a_b') * Δ,), δa, -δa end return d, SqMahalanobis_pullback diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 14d36586b..7379f0cbd 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -32,8 +32,8 @@ end # Also defines the colwise method for abstractvectors -function colwise(::PreMetric, x::AbstractVector) - return zeros(length(x)) # Valid since d(x,x) == 0 by definition +function colwise(d::PreMetric, x::AbstractVector) + return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition end ## The following is a hack for DotProduct and Delta to still work diff --git a/src/zygoterules.jl b/src/zygoterules.jl new file mode 100644 index 000000000..3dbd99a3a --- /dev/null +++ b/src/zygoterules.jl @@ -0,0 +1,7 @@ +ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) + return ZygoteRules.pullback(_map, t, X) +end + +ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) + return ZygoteRules.pullback(_map, t, X) +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 43bd03769..1f276008f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" diff --git a/test/runtests.jl b/test/runtests.jl index a4df89988..9ca50b15a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using Zygote: Zygote using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences +using ChainRulesTestUtils using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils @@ -147,6 +148,7 @@ include("test_utils.jl") include("generic.jl") include("chainrules.jl") + include("zygoterules.jl") @testset "doctests" begin DocMeta.setdocmeta!( diff --git a/test/zygoterules.jl b/test/zygoterules.jl new file mode 100644 index 000000000..e85be71b1 --- /dev/null +++ b/test/zygoterules.jl @@ -0,0 +1,3 @@ +@testset "zygoterules" begin + +end \ No newline at end of file From a44a7622a4cbc8f610a48c9132510a08a88056b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 16:50:01 +0100 Subject: [PATCH 40/48] Fixed the dims issue with pairwise --- src/chainrules.jl | 45 +++++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index d9e6a17e1..e351b29d3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -68,34 +68,28 @@ function ChainRulesCore.rrule( dims=2, ) P = Distances.pairwise(d, X, Y; dims=dims) - if dims == 1 - function pairwise_pullback_cols(Δ::AbstractMatrix) + function pairwise_pullback_cols(Δ::AbstractMatrix) + if dims == 1 return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X - end - return P, pairwise_pullback_cols - else - function pairwise_pullback_rows(Δ::AbstractMatrix) + else return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ end - return P, pairwise_pullback_rows end + return P, pairwise_pullback_cols end function ChainRulesCore.rrule( ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 ) P = Distances.pairwise(d, X; dims=dims) - if dims == 1 - function pairwise_pullback_cols(Δ::AbstractMatrix) + function pairwise_pullback_cols(Δ::AbstractMatrix) + if dims == 1 return NO_FIELDS, NO_FIELDS, 2 * Δ * X - end - return P, pairwise_pullback_cols - else - function pairwise_pullback_rows(Δ::AbstractMatrix) + else return NO_FIELDS, NO_FIELDS, 2 * X * Δ end - return P, pairwise_pullback_rows end + return P, pairwise_pullback_cols end function ChainRulesCore.rrule( @@ -122,6 +116,19 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) return val, evaluate_pullback end +## Reverse Rulse SqMahalanobis + +function ChainRulesCore.rrule(dist::Distances.SqMahalanobis, a, b) + d = dist(a, b) + function SqMahalanobis_pullback(Δ::Real) + B_Bᵀ = dist.qmat + transpose(dist.qmat) + a_b = a - b + δa = @thunk((B_Bᵀ * a_b) * Δ) + return (qmat=(a_b * a_b') * Δ,), δa, -δa + end + return d, SqMahalanobis_pullback +end + ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) @@ -148,13 +155,3 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) return RowVecs(X), RowVecs_pullback end -function ChainRulesCore.rrule(dist::Distances.SqMahalanobis, a, b) - d = dist(a, b) - function SqMahalanobis_pullback(Δ::Real) - B_Bᵀ = dist.qmat + transpose(dist.qmat) - a_b = a - b - δa = @thunk((B_Bᵀ * a_b) * Δ) - return (qmat=(a_b * a_b') * Δ,), δa, -δa - end - return d, SqMahalanobis_pullback -end From 72889ddc4f88a14f26aeb37f6aabbe6253d2c71c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 16:51:33 +0100 Subject: [PATCH 41/48] Fixed formatting --- src/chainrules.jl | 8 ++++++-- src/zygoterules.jl | 2 +- test/zygoterules.jl | 1 - 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index e351b29d3..5edb06e23 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,7 +1,9 @@ ## Forward Rules # Note that this is type piracy as the derivative should be NaN for x == y. -function ChainRulesCore.frule((_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector) +function ChainRulesCore.frule( + (_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector +) Δ = x - y D = sqrt(sum(abs2, Δ)) if !iszero(D) @@ -118,7 +120,9 @@ end ## Reverse Rulse SqMahalanobis -function ChainRulesCore.rrule(dist::Distances.SqMahalanobis, a, b) +function ChainRulesCore.rrule( + dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector +) d = dist(a, b) function SqMahalanobis_pullback(Δ::Real) B_Bᵀ = dist.qmat + transpose(dist.qmat) diff --git a/src/zygoterules.jl b/src/zygoterules.jl index 3dbd99a3a..88016613d 100644 --- a/src/zygoterules.jl +++ b/src/zygoterules.jl @@ -4,4 +4,4 @@ end ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) return ZygoteRules.pullback(_map, t, X) -end \ No newline at end of file +end diff --git a/test/zygoterules.jl b/test/zygoterules.jl index e85be71b1..a8073f39b 100644 --- a/test/zygoterules.jl +++ b/test/zygoterules.jl @@ -1,3 +1,2 @@ @testset "zygoterules" begin - end \ No newline at end of file From 25549c1452b6e816e354745a2b0a1f0dcc1f7a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 17:01:13 +0100 Subject: [PATCH 42/48] Missing @thunk --- src/KernelFunctions.jl | 2 +- src/chainrules.jl | 1 - test/zygoterules.jl | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index d270deaee..b2911f769 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,7 +55,7 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat -using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS +using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS, @thunk using Requires using Distances, LinearAlgebra using Functors diff --git a/src/chainrules.jl b/src/chainrules.jl index 5edb06e23..19cbc80ad 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -158,4 +158,3 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) end return RowVecs(X), RowVecs_pullback end - diff --git a/test/zygoterules.jl b/test/zygoterules.jl index a8073f39b..15f113547 100644 --- a/test/zygoterules.jl +++ b/test/zygoterules.jl @@ -1,2 +1 @@ -@testset "zygoterules" begin -end \ No newline at end of file +@testset "zygoterules" begin end \ No newline at end of file From bbe5c7c58e04212b3552f02caae08facea0b5b1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 17:52:30 +0100 Subject: [PATCH 43/48] Putting back Composite to Any --- src/chainrules.jl | 4 ++-- test/zygoterules.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 19cbc80ad..eec28bb8f 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -136,7 +136,7 @@ end ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) - ColVecs_pullback(Δ::Composite{<:ColVecs}) = (NO_FIELDS, Δ.X) + ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * @@ -148,7 +148,7 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) end function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) - RowVecs_pullback(Δ::Composite{<:RowVecs}) = (NO_FIELDS, Δ.X) + RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * diff --git a/test/zygoterules.jl b/test/zygoterules.jl index 15f113547..dc3bb98fe 100644 --- a/test/zygoterules.jl +++ b/test/zygoterules.jl @@ -1 +1 @@ -@testset "zygoterules" begin end \ No newline at end of file +@testset "zygoterules" begin end From e08dbf41c64ea41ac7438033563c73a39c0c6387 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 24 Mar 2021 18:40:51 +0100 Subject: [PATCH 44/48] add @thunk for -delta a --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index eec28bb8f..1af1ef6fc 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -128,7 +128,7 @@ function ChainRulesCore.rrule( B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b δa = @thunk((B_Bᵀ * a_b) * Δ) - return (qmat=(a_b * a_b') * Δ,), δa, -δa + return (qmat=(a_b * a_b') * Δ,), δa, @thunk(-δa) end return d, SqMahalanobis_pullback end From 48bd6815703803be0d0c7fefe9a307df1b549165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Mar 2021 10:58:53 +0100 Subject: [PATCH 45/48] Update src/chainrules.jl Co-authored-by: David Widmann --- src/chainrules.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 1af1ef6fc..de7e5cf71 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -125,10 +125,20 @@ function ChainRulesCore.rrule( ) d = dist(a, b) function SqMahalanobis_pullback(Δ::Real) - B_Bᵀ = dist.qmat + transpose(dist.qmat) a_b = a - b - δa = @thunk((B_Bᵀ * a_b) * Δ) - return (qmat=(a_b * a_b') * Δ,), δa, @thunk(-δa) + ∂qmat = InplaceableThunk( + @thunk((a_b * a_b') * Δ), + X̄ -> mul!(X̄, a_b, a_b', true, Δ), + ) + ∂a = InplaceableThunk( + @thunk((2 * Δ) * dist.qmat * a_b), + X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), + ) + ∂b = InplaceableThunk( + @thunk((-2 * Δ) * dist.qmat * a_b), + X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), + ) + return Composite{typeof(dist)}(qmat = ∂qmat), ∂a, ∂b end return d, SqMahalanobis_pullback end From 3298d34353ba9fb2a17584549f0f6804162795e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Mar 2021 11:39:17 +0100 Subject: [PATCH 46/48] Update KernelFunctions.jl --- src/KernelFunctions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index b2911f769..8746b70e4 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,7 +55,8 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat -using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS, @thunk +using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS +using ChainRulesCore: @thunk, InplaceableThunk using Requires using Distances, LinearAlgebra using Functors From 0b99771abc7ae5e44c9433531196d380b6439b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Mar 2021 11:40:03 +0100 Subject: [PATCH 47/48] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index de7e5cf71..3c871d89a 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -127,18 +127,15 @@ function ChainRulesCore.rrule( function SqMahalanobis_pullback(Δ::Real) a_b = a - b ∂qmat = InplaceableThunk( - @thunk((a_b * a_b') * Δ), - X̄ -> mul!(X̄, a_b, a_b', true, Δ), + @thunk((a_b * a_b') * Δ), X̄ -> mul!(X̄, a_b, a_b', true, Δ) ) ∂a = InplaceableThunk( - @thunk((2 * Δ) * dist.qmat * a_b), - X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), + @thunk((2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ) ) ∂b = InplaceableThunk( - @thunk((-2 * Δ) * dist.qmat * a_b), - X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), + @thunk((-2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ) ) - return Composite{typeof(dist)}(qmat = ∂qmat), ∂a, ∂b + return Composite{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b end return d, SqMahalanobis_pullback end From c26edf3adc8b79f995bf7a892388232801adecde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Mar 2021 14:33:00 +0100 Subject: [PATCH 48/48] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 99c5b60d3..542e490ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.8.24" +version = "0.8.26" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"