|
1 | 1 | Base.@propagate_inbounds function rcopyto_at!( |
2 | 2 | pair::Pair{<:AbstractData, <:Any}, |
3 | | - I, |
| 3 | + cart_inds, |
| 4 | + tidx, |
4 | 5 | us, |
5 | 6 | ) |
6 | 7 | dest, bc = pair.first, pair.second |
7 | | - if is_valid_index(dest, I, us) |
| 8 | + if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds)) |
| 9 | + I = unval(cart_inds)[tidx] |
8 | 10 | dest[I] = isascalar(bc) ? bc[] : bc[I] |
9 | 11 | end |
10 | 12 | return nothing |
11 | 13 | end |
12 | | -Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, us) |
| 14 | +Base.@propagate_inbounds function rcopyto_at!( |
| 15 | + pair::Pair{<:DataF, <:Any}, |
| 16 | + cart_inds, |
| 17 | + tidx, |
| 18 | + us, |
| 19 | +) |
13 | 20 | dest, bc = pair.first, pair.second |
14 | | - if is_valid_index(dest, I, us) |
| 21 | + if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds)) |
| 22 | + I = unval(cart_inds)[tidx] |
15 | 23 | bcI = isascalar(bc) ? bc[] : bc[I] |
16 | 24 | dest[] = bcI |
17 | 25 | end |
18 | 26 | return nothing |
19 | 27 | end |
20 | | -Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, us) |
21 | | - rcopyto_at!(first(pairs), I, us) |
22 | | - rcopyto_at!(Base.tail(pairs), I, us) |
| 28 | +Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, cart_inds, tidx, us) |
| 29 | + rcopyto_at!(first(pairs), cart_inds, tidx, us) |
| 30 | + rcopyto_at!(Base.tail(pairs), cart_inds, tidx, us) |
23 | 31 | end |
24 | | -Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, us) = |
25 | | - rcopyto_at!(first(pairs), I, us) |
26 | | -@inline rcopyto_at!(pairs::Tuple{}, I, us) = nothing |
| 32 | +Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, cart_inds, tidx, us) = |
| 33 | + rcopyto_at!(first(pairs), cart_inds, tidx, us) |
| 34 | +@inline rcopyto_at!(pairs::Tuple{}, cart_inds, tidx, us) = nothing |
27 | 35 |
|
28 | | -function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us) |
| 36 | +function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us, cart_inds) |
29 | 37 | @inbounds begin |
30 | | - I = universal_index(dest1) |
31 | | - if is_valid_index(dest1, I, us) |
| 38 | + tidx = linear_thread_idx() |
| 39 | + if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds)) |
32 | 40 | (; pairs) = fmbc |
33 | | - rcopyto_at!(pairs, I, us) |
| 41 | + rcopyto_at!(pairs, cart_inds, tidx, us) |
34 | 42 | end |
35 | 43 | end |
36 | 44 | return nothing |
@@ -138,10 +146,11 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast) |
138 | 146 | blocks_s = p.blocks, |
139 | 147 | ) |
140 | 148 | else |
141 | | - args = (fmb, dest1, us) |
| 149 | + cart_inds = cartesian_indices(us) |
| 150 | + args = (fmb, dest1, us, cart_inds) |
142 | 151 | threads = threads_via_occupancy(knl_fused_copyto!, args) |
143 | 152 | n_max_threads = min(threads, get_N(us)) |
144 | | - p = partition(dest1, n_max_threads) |
| 153 | + p = linear_partition(prod(size(dest1)), n_max_threads) |
145 | 154 | auto_launch!( |
146 | 155 | knl_fused_copyto!, |
147 | 156 | args; |
|
0 commit comments