@@ -63,7 +63,7 @@ julia> augment_batch_dim(A, 3)
63
63
```
64
64
"""
65
65
function augment_batch_dim (input:: AbstractArray{T,N} , n) where {T,N}
66
- return repeat (input; inner= (ntuple (_ -> 1 , Val ( N - 1 ) )... , n))
66
+ return repeat (input; inner= (ntuple (Returns ( 1 ), N - 1 )... , n))
67
67
end
68
68
69
69
"""
72
72
Reduce augmented input batch by averaging the explanation for each augmented sample.
73
73
"""
74
74
function reduce_augmentation (input:: AbstractArray{T,N} , n) where {T<: AbstractFloat ,N}
75
- return cat (
76
- (
77
- Iterators. map (1 : n: size (input, N)) do i
78
- augmentation_range = ntuple (_ -> :, Val (N - 1 ))... , i: (i + n - 1 )
79
- sum (view (input, augmentation_range... ); dims= N) / n
80
- end
81
- ). .. ; dims= N
82
- ):: Array{T,N}
75
+ # Allocate output array
76
+ in_size = size (input)
77
+ in_size[end ] % n != 0 &&
78
+ throw (ArgumentError (" Can't reduce augmented batch size of $(in_size[end ]) by $n " ))
79
+ out_size = (in_size[1 : (end - 1 )]. .. , div (in_size[end ], n))
80
+ out = similar (input, eltype (input), out_size)
81
+
82
+ axs = axes (input, N)
83
+ inds_before_N = ntuple (Returns (:), N - 1 )
84
+ for (i, ax) in enumerate (first (axs): n: last (axs))
85
+ view (out, inds_before_N... , i) .=
86
+ sum (view (input, inds_before_N... , ax: (ax + n - 1 )); dims= N) / n
87
+ end
88
+ return out
83
89
end
84
90
"""
85
91
augment_indices(indices, n)
0 commit comments