|
1 | | -@kernel inbounds=true cpu=false function _mapreduce_nd_by_thread!(@Const(src), dst, f, op, init, dims) |
2 | | - |
| 1 | +@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_thread!( |
| 2 | + @Const(src), |
| 3 | + dst, |
| 4 | + f, |
| 5 | + op, |
| 6 | + init, |
| 7 | + dims, |
| 8 | +) |
3 | 9 | # One thread per output element, when there are more outer elements than in the reduced dim |
4 | 10 | # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim |
5 | 11 | src_sizes = size(src) |
|
64 | 70 | end |
65 | 71 |
|
66 | 72 |
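Note: the "one thread per output element" strategy of this first kernel amounts to the plain-Julia loop below. This is a hypothetical CPU-side reference using Cartesian indices rather than the kernel's raw 0-based strides, and `reduce_by_thread` is an illustrative name, not a function from the package; it mirrors the intent of the kernel, not its exact arithmetic.

    function reduce_by_thread(src, f, op, init, dims)
        dst = fill(init, ntuple(i -> i == dims ? 1 : size(src, i), ndims(src)))
        for I in CartesianIndices(dst)                  # one iteration plays the role of one GPU thread
            acc = init
            for k in axes(src, dims)                    # walk the reduced axis serially
                J = ntuple(i -> i == dims ? k : I[i], ndims(src))
                acc = op(acc, f(src[J...]))
            end
            dst[I] = acc
        end
        return dst
    end

    A = rand(3, 1000)
    reduce_by_thread(A, identity, +, 0.0, 1) ≈ sum(A; dims = 1)   # true

With only 3 output elements left after reducing dims=1, serial per-thread work over the long outer extent is the cheap part, which is why this variant is preferred when the outer dimensions dominate.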
|
67 | | -@kernel inbounds=true cpu=false function _mapreduce_nd_by_block!(@Const(src), dst, f, op, init, neutral, dims) |
68 | | - |
| 73 | +@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_block!( |
| 74 | + @Const(src), |
| 75 | + dst, |
| 76 | + f, |
| 77 | + op, |
| 78 | + init, |
| 79 | + neutral, |
| 80 | + dims, |
| 81 | +) |
69 | 82 | # One block per output element, when there are more elements in the reduced dim than in outer |
70 | 83 | # e.g. reduce(+, rand(3, 1000), dims=2) => only 3 elements in outer dimensions |
71 | 84 | src_sizes = size(src) |
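Note: the by-block variant instead gives every output element a whole workgroup that cooperatively folds the reduced axis. A launch would therefore use one workgroup per output element, roughly as in the hypothetical sketch below; the `block_size` value and the way `N`/`sdata` are configured inside the kernel are assumptions, and the package's actual dispatcher is not shown in this diff.

    using KernelAbstractions

    # src, dst are assumed to be GPU arrays (the kernel is compiled with cpu=false);
    # f, op, init, neutral, dims as passed down from the public mapreduce entry point.
    backend    = get_backend(src)
    block_size = 256                                    # assumed; must agree with however N and sdata are sized inside the kernel
    kernel!    = _mapreduce_nd_by_block!(backend, block_size)
    # one workgroup of block_size threads per output element:
    kernel!(src, dst, f, op, init, neutral, dims; ndrange = block_size * length(dst))
    KernelAbstractions.synchronize(backend)

As the comments above describe, the choice between the by-thread and by-block kernels presumably follows from which extent dominates: many outer elements favour the former, a long reduced dimension favours the latter.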
|
90 | 103 | iblock = @index(Group, Linear) - 0x1 |
91 | 104 | ithread = @index(Local, Linear) - 0x1 |
92 | 105 |
|
93 | | - # Each block handles one output element |
94 | | - if iblock < output_size |
95 | | - |
96 | | - # # Sometimes slightly faster method using additional memory with |
97 | | - # # output_idx = @private typeof(iblock) (ndims,) |
98 | | - # tmp = iblock |
99 | | - # KernelAbstractions.Extras.@unroll for i in ndims:-1:1 |
100 | | - # output_idx[i] = tmp ÷ dst_strides[i] |
101 | | - # tmp = tmp % dst_strides[i] |
102 | | - # end |
103 | | - # # Compute the base index in src (excluding the reduced axis) |
104 | | - # input_base_idx = 0 |
105 | | - # KernelAbstractions.Extras.@unroll for i in 1:ndims |
106 | | - # i == dims && continue |
107 | | - # input_base_idx += output_idx[i] * src_strides[i] |
108 | | - # end |
109 | | - |
110 | | - # Compute the base index in src (excluding the reduced axis) |
111 | | - input_base_idx = typeof(ithread)(0) |
112 | | - tmp = iblock |
113 | | - KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16 |
114 | | - if i != dims |
115 | | - input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i] |
116 | | - end |
117 | | - tmp = tmp % dst_strides[i] |
| 106 | + # Each block handles one output element - thus, iblock ∈ [0, output_size) |
| 107 | + |
| 108 | + # # Sometimes slightly faster method using additional memory with |
| 109 | + # # output_idx = @private typeof(iblock) (ndims,) |
| 110 | + # tmp = iblock |
| 111 | + # KernelAbstractions.Extras.@unroll for i in ndims:-1:1 |
| 112 | + # output_idx[i] = tmp ÷ dst_strides[i] |
| 113 | + # tmp = tmp % dst_strides[i] |
| 114 | + # end |
| 115 | + # # Compute the base index in src (excluding the reduced axis) |
| 116 | + # input_base_idx = 0 |
| 117 | + # KernelAbstractions.Extras.@unroll for i in 1:ndims |
| 118 | + # i == dims && continue |
| 119 | + # input_base_idx += output_idx[i] * src_strides[i] |
| 120 | + # end |
| 121 | + |
| 122 | + # Compute the base index in src (excluding the reduced axis) |
| 123 | + input_base_idx = typeof(ithread)(0) |
| 124 | + tmp = iblock |
| 125 | + KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16 |
| 126 | + if i != dims |
| 127 | + input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i] |
118 | 128 | end |
| 129 | + tmp = tmp % dst_strides[i] |
| 130 | + end |
119 | 131 |
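Note: the unrolled loop above recovers the output coordinates from the flat block index by repeated division with the destination strides (largest stride first, i.e. column-major order) and re-encodes them with the source strides while skipping the reduced axis. The same arithmetic in plain Julia, with 0-based offsets as in the kernel; `source_base_offset` is an illustrative helper, not part of the package:

    function source_base_offset(iblock, dst_strides, src_strides, dims)
        base = 0
        tmp = iblock
        for i in length(dst_strides):-1:1               # most significant (largest) stride first
            if i != dims
                base += (tmp ÷ dst_strides[i]) * src_strides[i]
            end
            tmp %= dst_strides[i]
        end
        return base                                     # reduced-axis element k then sits at base + k*src_strides[dims]
    end

    # e.g. src of size (3, 1000) reduced along dims = 2: dst is 3×1, so
    # dst_strides = (1, 3), src_strides = (1, 3), and block 2 (0-based) starts at offset 2,
    # i.e. it reduces src[3, 1], src[3, 2], ... (offsets 2, 5, 8, ...).
    source_base_offset(2, (1, 3), (1, 3), 2) == 2       # true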
|
120 | | - # We have a block of threads to process the whole reduced dimension. First do pre-reduction |
121 | | - # in strides of N |
122 | | - partial = neutral |
123 | | - i = ithread |
124 | | - while i < reduce_size |
125 | | - src_idx = input_base_idx + i * src_strides[dims] |
126 | | - partial = op(partial, f(src[src_idx + 0x1])) |
127 | | - i += N |
128 | | - end |
| 132 | + # We have a block of threads to process the whole reduced dimension. First do pre-reduction |
| 133 | + # in strides of N |
| 134 | + partial = neutral |
| 135 | + i = ithread |
| 136 | + while i < reduce_size |
| 137 | + src_idx = input_base_idx + i * src_strides[dims] |
| 138 | + partial = op(partial, f(src[src_idx + 0x1])) |
| 139 | + i += N |
| 140 | + end |
129 | 141 |
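Note: this pre-reduction lets a block of N threads cover a reduced axis of any length: thread t folds elements t, t+N, t+2N, ... into one partial value, so at most N partials remain regardless of reduce_size. A CPU-side sketch of the same access pattern (illustrative only; the concrete sizes and functions below are assumptions):

    function pre_reduce(src, f, op, neutral, N, base, stride, reduce_size)
        partials = fill(neutral, N)
        for ithread in 0:N-1                            # on the GPU these run concurrently
            i = ithread
            while i < reduce_size
                partials[ithread + 1] = op(partials[ithread + 1], f(src[base + i*stride + 1]))
                i += N
            end
        end
        return partials                                 # plays the role of sdata below
    end

    A = rand(3, 1000)
    p = pre_reduce(A, abs2, +, 0.0, 256, 0, strides(A)[2], size(A, 2))
    sum(p) ≈ mapreduce(abs2, +, @view(A[1, :]))         # true: block 0 reduces the first row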
|
130 | | - # Store partial result in shared memory; now we are down to a single block to reduce within |
131 | | - sdata[ithread + 0x1] = partial |
132 | | - @synchronize() |
| 142 | + # Store partial result in shared memory; now we are down to a single block to reduce within |
| 143 | + sdata[ithread + 0x1] = partial |
| 144 | + @synchronize() |
133 | 145 |
|
134 | | - if N >= 512u16 |
135 | | - ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])) |
136 | | - @synchronize() |
137 | | - end |
138 | | - if N >= 256u16 |
139 | | - ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])) |
140 | | - @synchronize() |
141 | | - end |
142 | | - if N >= 128u16 |
143 | | - ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])) |
144 | | - @synchronize() |
145 | | - end |
146 | | - if N >= 64u16 |
147 | | - ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])) |
148 | | - @synchronize() |
149 | | - end |
150 | | - if N >= 32u16 |
151 | | - ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])) |
152 | | - @synchronize() |
153 | | - end |
154 | | - if N >= 16u16 |
155 | | - ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])) |
156 | | - @synchronize() |
157 | | - end |
158 | | - if N >= 8u16 |
159 | | - ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])) |
160 | | - @synchronize() |
161 | | - end |
162 | | - if N >= 4u16 |
163 | | - ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])) |
164 | | - @synchronize() |
165 | | - end |
166 | | - if N >= 2u16 |
167 | | - ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1])) |
168 | | - @synchronize() |
169 | | - end |
170 | | - if ithread == 0x0 |
171 | | - dst[iblock + 0x1] = op(init, sdata[0x1]) |
172 | | - end |
| 146 | + if N >= 512u16 |
| 147 | + ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])) |
| 148 | + @synchronize() |
| 149 | + end |
| 150 | + if N >= 256u16 |
| 151 | + ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])) |
| 152 | + @synchronize() |
| 153 | + end |
| 154 | + if N >= 128u16 |
| 155 | + ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])) |
| 156 | + @synchronize() |
| 157 | + end |
| 158 | + if N >= 64u16 |
| 159 | + ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])) |
| 160 | + @synchronize() |
| 161 | + end |
| 162 | + if N >= 32u16 |
| 163 | + ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])) |
| 164 | + @synchronize() |
| 165 | + end |
| 166 | + if N >= 16u16 |
| 167 | + ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])) |
| 168 | + @synchronize() |
| 169 | + end |
| 170 | + if N >= 8u16 |
| 171 | + ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])) |
| 172 | + @synchronize() |
| 173 | + end |
| 174 | + if N >= 4u16 |
| 175 | + ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])) |
| 176 | + @synchronize() |
| 177 | + end |
| 178 | + if N >= 2u16 |
| 179 | + ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1])) |
| 180 | + @synchronize() |
| 181 | + end |
| 182 | + if ithread == 0x0 |
| 183 | + dst[iblock + 0x1] = op(init, sdata[0x1]) |
173 | 184 | end |
174 | 185 | end |
175 | 186 |
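Note: the unrolled `if N >= 512u16 ... if N >= 2u16` cascade is a standard shared-memory tree reduction: each step halves the number of active threads and folds the upper half of sdata into the lower half, with a barrier in between, until sdata[1] holds the block's result and thread 0 writes op(init, sdata[1]) to dst. A serial sketch of the same pattern (illustrative; on the GPU the inner loop is the part that runs in parallel, and @synchronize separates the steps):

    function tree_reduce!(sdata, op)
        stride = length(sdata) ÷ 2                      # length(sdata) = N, a power of two
        while stride >= 1
            for t in 0:stride-1                         # these iterations run in parallel on the GPU
                sdata[t + 1] = op(sdata[t + 1], sdata[t + stride + 1])
            end                                         # ... followed by @synchronize()
            stride ÷= 2
        end
        return sdata[1]
    end

    vals = rand(256)
    tree_reduce!(copy(vals), +) ≈ sum(vals)             # true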
|
|