Skip to content

Commit f1b89ab

Browse files
authored
Remove use of mapreduce (#58)
* Update `gradients_wrt_batch` * Update `reduce_augmentation`
1 parent 4ec4908 commit f1b89ab

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

src/gradient.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ end
44

55
function gradients_wrt_batch(model, input::AbstractArray{T,N}, output_indices) where {T,N}
66
# To avoid computing a sparse jacobian, we compute individual gradients
7-
# by mapping `gradient_wrt_input` on slices of the input along the batch dimension.
8-
return mapreduce(
9-
(gs...) -> cat(gs...; dims=N), zip(eachslice(input; dims=N), output_indices)
10-
) do (in, idx)
11-
gradient_wrt_input(model, batch_dim_view(in), drop_batch_index(idx))
7+
# by calling `gradient_wrt_input` on slices of the input along the batch dimension.
8+
out = similar(input)
9+
inds_before_N = ntuple(Returns(:), N - 1)
10+
for (i, ax) in enumerate(axes(input, N))
11+
view(out, inds_before_N..., ax, :) .= gradient_wrt_input(
12+
model, view(input, inds_before_N..., ax, :), drop_batch_index(output_indices[i])
13+
)
1214
end
15+
return out
1316
end
1417

1518
"""

src/input_augmentation.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ julia> augment_batch_dim(A, 3)
6363
```
6464
"""
6565
function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N}
66-
return repeat(input; inner=(ntuple(_ -> 1, Val(N - 1))..., n))
66+
return repeat(input; inner=(ntuple(Returns(1), N - 1)..., n))
6767
end
6868

6969
"""
@@ -72,14 +72,20 @@ end
7272
Reduce augmented input batch by averaging the explanation for each augmented sample.
7373
"""
7474
function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N}
75-
return cat(
76-
(
77-
Iterators.map(1:n:size(input, N)) do i
78-
augmentation_range = ntuple(_ -> :, Val(N - 1))..., i:(i + n - 1)
79-
sum(view(input, augmentation_range...); dims=N) / n
80-
end
81-
)...; dims=N
82-
)::Array{T,N}
75+
# Allocate output array
76+
in_size = size(input)
77+
in_size[end] % n != 0 &&
78+
throw(ArgumentError("Can't reduce augmented batch size of $(in_size[end]) by $n"))
79+
out_size = (in_size[1:(end - 1)]..., div(in_size[end], n))
80+
out = similar(input, eltype(input), out_size)
81+
82+
axs = axes(input, N)
83+
inds_before_N = ntuple(Returns(:), N - 1)
84+
for (i, ax) in enumerate(first(axs):n:last(axs))
85+
view(out, inds_before_N..., i) .=
86+
sum(view(input, inds_before_N..., ax:(ax + n - 1)); dims=N) / n
87+
end
88+
return out
8389
end
8490
"""
8591
augment_indices(indices, n)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ julia> batch_dim_view(A)
3939
3 4
4040
```
4141
"""
42-
batch_dim_view(A::AbstractArray{T,N}) where {T,N} = view(A, ntuple(_ -> :, Val(N + 1))...)
42+
batch_dim_view(A::AbstractArray{T,N}) where {T,N} = view(A, ntuple(Returns(:), N + 1)...)
4343

4444
"""
4545
drop_batch_index(I)

0 commit comments

Comments
 (0)