Skip to content

Commit a8825c9

Browse files
author
Miha Zgubic
committed
add implementation
1 parent 237d5f6 commit a8825c9

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("rulesets/Base/evalpoly.jl")
3636
include("rulesets/Base/array.jl")
3737
include("rulesets/Base/arraymath.jl")
3838
include("rulesets/Base/indexing.jl")
39+
include("rulesets/Base/sort.jl")
3940
include("rulesets/Base/mapreduce.jl")
4041

4142
include("rulesets/Statistics/statistics.jl")

src/rulesets/Base/sort.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
2+
inds = partialsortperm(xs, k; kwargs...)
3+
ys = xs[inds]
4+
5+
function partialsort_pullback(Δys)
6+
function partialsort_add!(Δxs)
7+
for (Δy, i) in zip(Δys, inds)
8+
Δxs[i] += Δy
9+
end
10+
return Δxs
11+
end
12+
13+
Δxs = InplaceableThunk(
14+
@thunk(partialsort_add!(zero(xs))),
15+
partialsort_add!
16+
)
17+
return NO_FIELDS, Δxs, DoesNotExist()
18+
end
19+
return ys, partialsort_pullback
20+
end
21+
22+
function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
23+
inds = sortperm(xs; kwargs...)
24+
ys = xs[inds]
25+
26+
function sort_pullback(Δys)
27+
function sort_add!(Δxs)
28+
Δxs[inds] += Δys
29+
return Δxs
30+
end
31+
32+
Δxs = InplaceableThunk(
33+
@thunk(sort_add!(zero(xs))),
34+
sort_add!
35+
)
36+
37+
return NO_FIELDS, Δxs
38+
end
39+
return ys, sort_pullback
40+
end

0 commit comments

Comments
 (0)