Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.6.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -14,6 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "2.2, 3"
Distances = "0.9"
Requires = "1.0.1"
Expand Down
5 changes: 3 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ using Compat
using Requires
using Distances, LinearAlgebra
using SpecialFunctions: loggamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using ChainRulesCore
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase


"""
Abstract type defining a slice-wise transformation on an input matrix
"""
Expand Down Expand Up @@ -74,7 +75,7 @@ include("generic.jl")
include("mokernels/moinput.jl")
include("mokernels/independent.jl")

include("zygote_adjoints.jl")
include("chainrules.jl")

function __init__()
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
Expand Down
90 changes: 90 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
## rules for Delta
function ChainRulesCore.rrule(::typeof(evaluate), s::Delta, x::AbstractVector, y::AbstractVector)
evaluate(s, x, y), Δ -> begin
(NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist())
end
end

function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = Distances.pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist())
else
return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist())
end
end

function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2)
D = Distances.pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist())
else
return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist())
end
end

## rules for DotProduct
function ChainRulesCore.rrule(::typeof(evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector)
dot(x, y), Δ -> begin
(NO_FIELDS, nothing, Δ .* y, Δ .* x)
end
end

function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = Distances.pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (NO_FIELDS, nothing, Δ * Y, (X' * Δ)')
else
return D, Δ -> (NO_FIELDS, nothing, (Δ * Y')', X * Δ)
end
end

function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2)
D = Distances.pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (NO_FIELDS, nothing, 2 * Δ * X)
else
return D, Δ -> (NO_FIELDS, nothing, 2 * X * Δ)
end
end

## rules for Sinus
function ChainRulesCore.rrule(::typeof(evaluate), s::Sinus, x::AbstractVector, y::AbstractVector)
d = (x - y)
sind = sinpi.(d)
val = sum(abs2, sind ./ s.r)
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
val, Δ -> begin
(NO_FIELDS, (r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx)
end
end


# rules for ColVecs and RowVecs
function ChainRulesCore.rrule(::typeof(ColVecs), X::AbstractMatrix)
back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,)
back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,)
function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
throw(error("In slow method"))
end
return ColVecs(X), back
end

function ChainRulesCore.rrule(::typeof(RowVecs), X::AbstractMatrix)
back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,)
back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,)
function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
throw(error("In slow method"))
end
return RowVecs(X), back
end


# rules for transforms
function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::ColVecs)
ChainRulesCore.rrule(_map, t, X)
end

function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::RowVecs)
ChainRulesCore.rrule(_map, t, X)
end
86 changes: 0 additions & 86 deletions src/zygote_adjoints.jl

This file was deleted.

4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -13,6 +15,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.5"
Distances = "0.9"
FiniteDifferences = "0.10.8"
Flux = "0.10, 0.11"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Random
using SpecialFunctions
using Test
using Flux
using ChainRulesTestUtils
import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences

using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs
Expand Down