|
4 | 4 | end
|
5 | 5 | end
|
6 | 6 |
|
| 7 | +@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) |
| 8 | + D = pairwise(d, X, Y; dims = dims) |
| 9 | + if dims == 1 |
| 10 | + return D, Δ -> (nothing, Δ * Y, (X' * Δ)') |
| 11 | + else |
| 12 | + return D, Δ -> (nothing, (Δ * Y')', X * Δ) |
| 13 | + end |
| 14 | +end |
| 15 | + |
| 16 | +@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2) |
| 17 | + D = pairwise(d, X; dims = dims) |
| 18 | + if dims == 1 |
| 19 | + return D, Δ -> (nothing, 2 * Δ * X) |
| 20 | + else |
| 21 | + return D, Δ -> (nothing, 2 * X * Δ) |
| 22 | + end |
| 23 | +end |
| 24 | + |
| 25 | +@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) |
| 26 | + d = (x - y) |
| 27 | + sind = sinpi.(d) |
| 28 | + val = sum(abs2, sind ./ s.r) |
| 29 | + gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) |
| 30 | + val, Δ -> begin |
| 31 | + ((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) |
| 32 | + end |
| 33 | +end |
| 34 | + |
| 35 | +@adjoint function pairwise(s::Sinus, X::AbstractMatrix, Y::AbstractMatrix; dims=2) |
| 36 | + D = pairwise(d, X, Y; dims = dims) |
| 37 | + throw(error("Sinus metric has no defined adjoint for now... PR welcome!")) |
| 38 | +end |
| 39 | + |
| 40 | +@adjoint function pairwise(s::Sinus, X::AbstractMatrix; dims=2) |
| 41 | + D = pairwise(d, X; dims = dims) |
| 42 | + throw(error("Sinus metric has no defined adjoint for now... PR welcome!")) |
| 43 | +end |
| 44 | + |
7 | 45 | @adjoint function loggamma(x)
|
8 | 46 | first(logabsgamma(x)) , Δ -> (Δ .* polygamma(0, x), )
|
9 | 47 | end
|
|
36 | 74 | return RowVecs(X), back
|
37 | 75 | end
|
38 | 76 |
|
39 |
| -# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) |
40 |
| -# d = evaluate(s, x, y) |
41 |
| -# s = sum(sin.(π*(x-y))) |
42 |
| -# d, Δ -> begin |
43 |
| -# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d) |
44 |
| -# end |
45 |
| -# end |
| 77 | +@adjoint function Base.map(t::Transform, X::ColVecs) |
| 78 | + pullback(_map, t, X) |
| 79 | +end |
| 80 | + |
| 81 | +@adjoint function Base.map(t::Transform, X::RowVecs) |
| 82 | + pullback(_map, t, X) |
| 83 | +end |
0 commit comments