Skip to content

Commit 6875aee

Browse files
committed
Created full adjoints for DotProduct and evaluate for Sinus
1 parent d88dcff commit 6875aee

File tree

1 file changed

+45
-7
lines changed

1 file changed

+45
-7
lines changed

src/zygote_adjoints.jl

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,44 @@
44
end
55
end
66

7+
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
8+
D = pairwise(d, X, Y; dims = dims)
9+
if dims == 1
10+
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
11+
else
12+
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
13+
end
14+
end
15+
16+
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
17+
D = pairwise(d, X; dims = dims)
18+
if dims == 1
19+
return D, Δ -> (nothing, 2 * Δ * X)
20+
else
21+
return D, Δ -> (nothing, 2 * X * Δ)
22+
end
23+
end
24+
25+
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
26+
d = (x - y)
27+
sind = sinpi.(d)
28+
val = sum(abs2, sind ./ s.r)
29+
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
30+
val, Δ -> begin
31+
((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx)
32+
end
33+
end
34+
35+
@adjoint function pairwise(s::Sinus, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
36+
D = pairwise(d, X, Y; dims = dims)
37+
throw(error("Sinus metric has no defined adjoint for now... PR welcome!"))
38+
end
39+
40+
@adjoint function pairwise(s::Sinus, X::AbstractMatrix; dims=2)
41+
D = pairwise(d, X; dims = dims)
42+
throw(error("Sinus metric has no defined adjoint for now... PR welcome!"))
43+
end
44+
745
@adjoint function loggamma(x)
846
first(logabsgamma(x)) , Δ ->.* polygamma(0, x), )
947
end
@@ -36,10 +74,10 @@ end
3674
return RowVecs(X), back
3775
end
3876

39-
# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
40-
# d = evaluate(s, x, y)
41-
# s = sum(sin.(π*(x-y)))
42-
# d, Δ -> begin
43-
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d)
44-
# end
45-
# end
77+
@adjoint function Base.map(t::Transform, X::ColVecs)
78+
pullback(_map, t, X)
79+
end
80+
81+
@adjoint function Base.map(t::Transform, X::RowVecs)
82+
pullback(_map, t, X)
83+
end

0 commit comments

Comments
 (0)