Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.56"
version = "0.10.57"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
10 changes: 6 additions & 4 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
"or because some external computation has acted on `ColVecs` to produce a vector of vectors." *
"If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
)
end
return ColVecs(X), ColVecs_pullback
Expand All @@ -162,8 +163,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
"or because some external computation has acted on `RowVecs` to produce a vector of vectors." *
"If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
)
end
return RowVecs(X), RowVecs_pullback
Expand Down