Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
9 changes: 9 additions & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using Tullio
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
Expand All @@ -67,7 +68,15 @@ abstract type Kernel end
abstract type SimpleKernel <: Kernel end

include("utils.jl")

# Union of the two vector-of-vectors types, `ColVecs` and `RowVecs`, so that a single
# method can be written for both.
const VecOfVecs = Union{ColVecs,RowVecs}

# A general binary op type not respecting Distances metric rules
abstract type AbstractBinaryOp end
# Anything usable as a binary operation here: either one of our own `AbstractBinaryOp`s
# or a `Distances.PreMetric`.
const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric}

include("distances/pairwise.jl")
include("distances/euclidean.jl")
include("distances/dotproduct.jl")
include("distances/delta.jl")
include("distances/sinus.jl")
Expand Down
6 changes: 2 additions & 4 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end
# Field-less operator type; it is an `AbstractBinaryOp` rather than a Distances metric.
struct Delta <: AbstractBinaryOp end

# Scalar evaluation: the result of comparing the two numbers with `==`.
function (dist::Delta)(a::Number, b::Number)
    return a == b
end
Base.@propagate_inbounds function (dist::Delta)(
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
Expand All @@ -14,5 +14,3 @@ Base.@propagate_inbounds function (dist::Delta)(
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
31 changes: 16 additions & 15 deletions src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
## DotProduct does not follow the PreMetric rules: d(x, x) != 0 in general, and d(x, y) >= 0 does not hold for all x, y
struct DotProduct <: Distances.UnionPreMetric end
struct DotProduct <: AbstractBinaryOp end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
throw(
DimensionMismatch(
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
),
)
end
return dot(a, b)
# Evaluate the operator on two vectors by delegating to `dot`.
function (::DotProduct)(a::AbstractVector, b::AbstractVector)
    return dot(a, b)
end

# Evaluate the operator on two scalars as their product.
function (::DotProduct)(a::Number, b::Number)
    return a * b
end

# Pairwise dot products between the columns of `x.X` and the columns of `y.X`:
# `out[i, j]` sums `x.X[k, i] * y.X[k, j]` over `k`.
#
# `y` defaults to `x`, mirroring the `colwise` methods for `DotProduct`, so that the
# one-argument call `pairwise(DotProduct(), x)` also takes this specialised contraction
# instead of falling back to the generic element-by-element method.
function pairwise(::DotProduct, x::ColVecs, y::ColVecs=x)
    return @tullio out[i, j] := x.X[k, i] * y.X[k, j]
end

Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
# Pairwise dot products between the rows of `x.X` and the rows of `y.X`:
# `out[i, j]` sums `x.X[i, k] * y.X[j, k]` over `k`.
#
# `y` defaults to `x`, mirroring the `colwise` methods for `DotProduct`, so that the
# one-argument call `pairwise(DotProduct(), x)` also takes this specialised contraction
# instead of falling back to the generic element-by-element method.
function pairwise(::DotProduct, x::RowVecs, y::RowVecs=x)
    return @tullio out[i, j] := x.X[i, k] * y.X[j, k]
end

# Elementwise-paired dot products for row-stored data: entry `n` pairs row `n` of `x.X`
# with row `n` of `y.X`. When `y` is omitted, `x` is paired with itself.
function colwise(::DotProduct, x::RowVecs, y::RowVecs=x)
    A = x.X
    B = y.X
    @tullio res[n] := A[n, f] * B[n, f]
    return res
end

@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray)
return Distances._evaluate(dist, a, b)
# Elementwise-paired dot products for column-stored data: entry `n` pairs column `n` of
# `x.X` with column `n` of `y.X`. When `y` is omitted, `x` is paired with itself.
function colwise(::DotProduct, x::ColVecs, y::ColVecs=x)
    A = x.X
    B = y.X
    @tullio res[n] := A[f, n] * B[f, n]
    return res
end
@inline (dist::DotProduct)(a::Number, b::Number) = a * b
19 changes: 19 additions & 0 deletions src/distances/euclidean.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Tullio specialization for Euclidean and SqEuclidean metrics

# Pairwise Euclidean distances between the columns of `x.X` and the columns of `y.X`.
#
# The sum over `k` is taken over squared differences and `sqrt` is applied once to the
# finished sum (Tullio's `<|` finaliser). Squared differences are used instead of the
# expanded form `x^2 - 2xy + y^2` because the expanded form cancels in floating point
# and can leave the sum slightly negative for nearly identical points, which would make
# `sqrt` throw a `DomainError`. Every term here is non-negative, and it is also cheaper.
function pairwise(::Euclidean, x::ColVecs, y::ColVecs)
    return @tullio out[i, j] := sqrt <| (x.X[k, i] - y.X[k, j])^2
end

# Pairwise Euclidean distances between the rows of `x.X` and the rows of `y.X`.
#
# The sum over `k` is taken over squared differences and `sqrt` is applied once to the
# finished sum (Tullio's `<|` finaliser). Squared differences are used instead of the
# expanded form `x^2 - 2xy + y^2` because the expanded form cancels in floating point
# and can leave the sum slightly negative for nearly identical points, which would make
# `sqrt` throw a `DomainError`. Every term here is non-negative, and it is also cheaper.
function pairwise(::Euclidean, x::RowVecs, y::RowVecs)
    return @tullio out[i, j] := sqrt <| (x.X[i, k] - y.X[j, k])^2
end

# Pairwise squared Euclidean distances between the columns of `x.X` and the columns of
# `y.X`: `out[i, j]` sums `(x.X[k, i] - y.X[k, j])^2` over `k`.
#
# Squared differences are summed directly instead of the expanded `x^2 - 2xy + y^2`:
# the expanded form cancels in floating point and can return small negative values for
# nearly identical points, whereas every term here is non-negative.
function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2
end

# Pairwise squared Euclidean distances between the rows of `x.X` and the rows of `y.X`:
# `out[i, j]` sums `(x.X[i, k] - y.X[j, k])^2` over `k`.
#
# Squared differences are summed directly instead of the expanded `x^2 - 2xy + y^2`:
# the expanded form cancels in floating point and can return small negative values for
# nearly identical points, whereas every term here is non-negative.
function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs)
    return @tullio out[i, j] := (x.X[i, k] - y.X[j, k])^2
end
66 changes: 12 additions & 54 deletions src/distances/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,70 +1,28 @@
# Add our own pairwise function to be able to apply it on vectors

function pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector)
return broadcast(d, X, permutedims(Y))
# Generic fallback: build the matrix whose `(p, q)` entry is `d(X[p], Y[q])`.
# With `Y` omitted, `X` is compared against itself.
function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector=X)
    @tullio K[p, q] := d(X[p], Y[q])
    return K
end

pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)

function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
return broadcast!(d, out, X, permutedims(Y))
end

pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)

function pairwise(d::PreMetric, x::AbstractVector{<:Real})
return Distances_pairwise(d, reshape(x, :, 1); dims=1)
end

function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
end

function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1)
end

function pairwise!(
out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
)
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
# In-place generic fallback: overwrite `out[p, q]` with `d(X[p], Y[q])` and return `out`.
# With `Y` omitted, `X` is compared against itself.
function pairwise!(
    out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector=X
)
    @tullio out[p, q] = d(X[p], Y[q])
    return out
end

# Also defines the colwise method for abstractvectors

function colwise(d::PreMetric, x::AbstractVector)
# `PreMetric` and `AbstractBinaryOp` get separate single-argument `colwise` methods,
# because only a pre-metric guarantees that `d(x, x)` is zero.
function colwise(d::Distances.PreMetric, x::AbstractVector)
    T = Distances.result_type(d, x, x)
    # Valid since d(x,x) == 0 by definition
    return zeros(T, length(x))
end

function colwise(d::PreMetric, x::ColVecs)
# Same shortcut for `ColVecs`/`RowVecs`: the result type is taken from the wrapped
# matrix `x.X`, and one zero is returned per element of `x`.
function colwise(d::Distances.PreMetric, x::VecOfVecs)
    T = Distances.result_type(d, x.X, x.X)
    # Valid since d(x,x) == 0 by definition
    return zeros(T, length(x))
end

function colwise(d::PreMetric, x::RowVecs)
return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition
end

## The following is a hack for DotProduct and Delta to still work
function colwise(d::Distances.UnionPreMetric, x::ColVecs)
return Distances.colwise(d, x.X, x.X)
end

function colwise(d::Distances.UnionPreMetric, x::RowVecs)
return Distances.colwise(d, x.X', x.X')
end

function colwise(d::Distances.UnionPreMetric, x::AbstractVector)
return map(d, x, x)
end

function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.colwise(d, x.X, y.X)
end

function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.colwise(d, x.X', y.X')
# For a general binary op `d(x, x)` has to be evaluated explicitly: entry `n` of the
# result is `d(x[n], x[n])`.
function colwise(d::AbstractBinaryOp, x::AbstractVector)
    @tullio vals[n] := d(x[n], x[n])
    return vals
end

function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector)
return map(d, x, y)
# Two-argument form: entry `n` of the result is `d(x[n], y[n])`.
function colwise(d::BinaryOp, x::AbstractVector, y::AbstractVector)
    @tullio vals[n] := d(x[n], y[n])
    return vals
end
4 changes: 2 additions & 2 deletions src/distances/sinus.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct Sinus{T} <: Distances.UnionSemiMetric
r::Vector{T}
# Distance type declared as a `Distances.SemiMetric`. It stores a single field `r`,
# whose type `V` may be any `AbstractVector` with element type `T`.
struct Sinus{T,V<:AbstractVector{T}} <: Distances.SemiMetric
r::V
end

# The parameters of a `Sinus` distance are its stored `r` field.
function Distances.parameters(d::Sinus)
    return d.r
end
Expand Down
38 changes: 7 additions & 31 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,6 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X))

dim(x::ColVecs) = size(x.X, 1)

pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
end
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
return Distances.pairwise!(out, d, x.X; dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.pairwise!(out, d, x.X, y.X; dims=2)
end

"""
RowVecs(X::AbstractMatrix)

Expand Down Expand Up @@ -150,25 +135,16 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X))

dim(x::RowVecs) = size(x.X, 2)

pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
end
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
return Distances.pairwise!(out, d, x.X; dims=1)
# Mixed ColVecs/RowVecs inputs would otherwise be ambiguous (#346). The `RowVecs`
# argument is re-wrapped as `ColVecs` over the permuted matrix, then dispatch continues.
function pairwise(d::BinaryOp, x::ColVecs, y::RowVecs)
    y_cols = ColVecs(permutedims(y.X))
    return pairwise(d, x, y_cols)
end
function pairwise(d::BinaryOp, x::RowVecs, y::ColVecs)
    x_cols = ColVecs(permutedims(x.X))
    return pairwise(d, x_cols, y)
end
# In-place mixed case: re-wrap the `RowVecs` argument as `ColVecs` over the permuted
# matrix and forward to the matching `pairwise!` method.
function pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs)
    y_cols = ColVecs(permutedims(y.X))
    return pairwise!(out, d, x, y_cols)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
# In-place mixed case: re-wrap the `RowVecs` argument as `ColVecs` over the permuted
# matrix and forward to the matching `pairwise!` method.
function pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs)
    x_cols = ColVecs(permutedims(x.X))
    return pairwise!(out, d, x_cols, y)
end

# Resolve ambiguity error for ColVecs vs RowVecs. #346
pairwise(d::PreMetric, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X)))
pairwise(d::PreMetric, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y)

# Fallback for arbitrary inputs. This is the passes-by-default choice. For a proper
# check, implement `KernelFunctions.dim` for your datatype.
function dim(x)
    return 0
end

# A generic vector reports the dimension of its first element.
function dim(x::AbstractVector)
    return dim(first(x))
end

# A vector of real-valued vectors reports the length of its first element.
function dim(x::AbstractVector{<:AbstractVector{<:Real}})
    return length(first(x))
end
Expand Down