Skip to content

Commit d4e854f

Browse files
gather(x, IJK...)
1 parent 0b64dc1 commit d4e854f

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/gather.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ or multiple `dst` columns.
5454
See [`gather!`](@ref) for an in-place version.
5555
5656
# Examples
57+
5758
```jldoctest
5859
julia> NNlib.gather([1,20,300,4000], [2,4,2])
5960
3-element Vector{Int64}:
@@ -83,5 +84,36 @@ function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::A
8384
y = gather!(dst, src, idx)
8485
src_size = size(src)
8586
gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())
86-
y, gather!_pullback
87+
return y, gather!_pullback
8788
end
89+
90+
"""
91+
gather(src, IJK...)
92+
93+
Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and
94+
call `gather` on it: `gather(src, CartesianIndex.(IJK...))`.
95+
96+
# Examples
97+
98+
```jldoctest
99+
julia> src = reshape([1:15;], 3, 5)
100+
3×5 Matrix{Int64}:
101+
1 4 7 10 13
102+
2 5 8 11 14
103+
3 6 9 12 15
104+
105+
julia> gather(src, [1, 2], [2, 4])
106+
2-element Vector{Int64}:
107+
4
108+
11
109+
```
110+
"""
111+
function gather(src::AbstractArray{Tsrc, Nsrc},
112+
IJK::AbstractVector{<:Integer}...) where {Nsrc, Tsrc}
113+
114+
return gather(src, to_cartesian_index(IJK))
115+
end
116+
117+
to_cartesian_index(IJK) = CartesianIndex.(IJK...)
118+
119+
@non_differentiable to_cartesian_index(idx::AbstractVector{<:Integer}...)

test/gather.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,14 @@ end
149149
gradtest(xs -> gather!(dst, xs, index), src)
150150
gradtest(xs -> gather(xs, index), src)
151151
end
152+
153+
@testset "gather(src, IJK...)" begin
154+
x = reshape([1:15;], 3, 5)
155+
156+
y = gather(x, [1,2], [2,4])
157+
@test y == [4, 11]
158+
159+
@test gather(x, [1, 2]) == [1 4
160+
2 5
161+
3 6]
162+
end

0 commit comments

Comments
 (0)