Skip to content

Commit c556832

Browse files
authored
Specialize broadcast to avoid integer divisions. (#304)
By using hardware 2d/3d indices whenever possible, and recompiling kernels for common broadcast shapes.
1 parent 0a0a8b4 commit c556832

File tree

2 files changed

+137
-1
lines changed

2 files changed

+137
-1
lines changed

src/broadcast.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,139 @@ BroadcastStyle(::MtlArrayStyle{N, S1},
1919
# allocation of output arrays
2020
Base.similar(bc::Broadcasted{MtlArrayStyle{N,S}}, ::Type{T}, dims) where {T,N,S} =
2121
similar(MtlArray{T,length(dims),S}, dims)
22+
23+
"""
    StaticCartesianIndices{N, I}

A static version of `CartesianIndices` that helps avoid integer division (at the expense
of additional compilation): the index ranges `I` are stored as a type parameter instead
of in a field, so the type itself has no fields.

Construct it from a `CartesianIndices` object, or from anything `CartesianIndices`
accepts. Convert back with `CartesianIndices(x)`.
"""
struct StaticCartesianIndices{N, I} end

# Lift the ranges of an existing `CartesianIndices` into the type domain.
function StaticCartesianIndices(cis::CartesianIndices{N}) where {N}
    return StaticCartesianIndices{N, cis.indices}()
end

# Anything else is first converted to `CartesianIndices`.
function StaticCartesianIndices(x)
    return StaticCartesianIndices(CartesianIndices(x))
end

# Rebuild the regular `CartesianIndices` from the ranges held in the type parameter.
function Base.CartesianIndices(::StaticCartesianIndices{N, I}) where {N, I}
    return CartesianIndices{N, typeof(I)}(I)
end

# Linear indexing; the index is converted to `Int` before the lookup.
Base.@propagate_inbounds function Base.getindex(sci::StaticCartesianIndices, i::Integer)
    return CartesianIndices(sci)[Int(i)]
end

function Base.length(sci::StaticCartesianIndices)
    return length(CartesianIndices(sci))
end

# Printed as the equivalent `CartesianIndices`, prefixed with "Static".
function Base.show(io::IO, sci::StaticCartesianIndices)
    print(io, "Static")
    return show(io, CartesianIndices(sci))
end
38+
39+
# specialization of the broadcast implementation to avoid expensive integer divisions
40+
# Number of times each destination shape (keyed by its `CartesianIndices`) has been
# broadcast into so far. Values are plain `Int` counters, so the value type is fixed.
const _broadcast_shapes = Base.IdDict{Any,Int}()
# A shape's counter must exceed this value before a kernel specialized on that shape
# is compiled.
const BROADCAST_SPECIALIZATION_THRESHOLD = 10
42+
# In-place broadcast entry point for `MtlArrayStyle`: rebuild the broadcast with the
# destination's axes, instantiate it, and hand it to `_copyto!`.
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:MtlArrayStyle}
    reshaped = Broadcasted{Style}(bc.f, bc.args, axes(dest))
    instantiated = Broadcast.instantiate(reshaped)
    return _copyto!(dest, instantiated)
end
45+
# Style-less broadcasts (`Broadcasted{Nothing}`) into an `MtlArray` go through `_copyto!`.
@inline function Base.copyto!(dest::MtlArray, bc::Broadcasted{Nothing})
    return _copyto!(dest, bc) # Keep it for ArrayConflict
end

# Broadcasts carrying an `MtlArrayStyle` go through `_copyto!` for any destination array.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:MtlArrayStyle})
    return _copyto!(dest, bc)
end
49+
"""
    _copyto!(dest::AbstractArray, bc::Broadcasted)

Evaluate the broadcast `bc` into `dest` by launching a Metal kernel, and return `dest`.

Calls `Broadcast.throwdm` when the axes of `dest` and `bc` differ, and returns right away
when `dest` is empty.

The kernel is picked to avoid integer divisions where possible:

- once the shape of `dest` has been counted more than `BROADCAST_SPECIALIZATION_THRESHOLD`
  times, a kernel specialized on `StaticCartesianIndices` of that shape is used;
- 1-dimensional destinations, or broadcasts where both sides are `IndexLinear`, index
  with the 1d thread position directly;
- 2- and 3-dimensional destinations use the hardware 2d/3d thread position;
- anything else indexes `CartesianIndices(dest)` with the 1d thread position.
"""
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    # if this is a common broadcast shape, specialize the kernel on it
    Is = CartesianIndices(dest)
    # bump the per-shape counter with a single lookup and a single store
    seen = get(_broadcast_shapes, Is, 0) + 1
    _broadcast_shapes[Is] = seen
    if seen > BROADCAST_SPECIALIZATION_THRESHOLD
        function broadcast_cartesian_static(dest, bc, Is)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                I = @inbounds Is[i]
                @inbounds dest[I] = bc[I]
                i += stride
            end
            return
        end

        # kept under its own name so that `Is` does not change type mid-function
        static_Is = StaticCartesianIndices(Is)
        kernel = @metal launch=false broadcast_cartesian_static(dest, bc, static_Is)
        # launch a quarter as many threads as elements; the loop in the kernel
        # covers the rest by advancing `stride` at a time
        elements = cld(length(dest), 4)
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
        kernel(dest, bc, static_Is; threads, groups)
        return dest
    end

    # try to use the most appropriate hardware index to avoid integer division
    if ndims(dest) == 1 ||
       (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
        function broadcast_linear(dest, bc)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                @inbounds dest[i] = bc[i]
                i += stride
            end
            return
        end

        kernel = @metal launch=false broadcast_linear(dest, bc)
        elements = cld(length(dest), 4)
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
    elseif ndims(dest) == 2
        function broadcast_2d(dest, bc)
            is = Tuple(thread_position_in_grid_2d())
            stride = threads_per_grid_2d()
            while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2)
                I = CartesianIndex(is)
                @inbounds dest[I] = bc[I]
                is = (is[1] + stride[1], is[2] + stride[2])
            end
            return
        end

        kernel = @metal launch=false broadcast_2d(dest, bc)
        # the launch below spans the whole of `dest`: `groups .* threads >= size(dest)`
        w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
        h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
        threads = (w, h)
        groups = cld.(size(dest), threads)
    elseif ndims(dest) == 3
        function broadcast_3d(dest, bc)
            is = Tuple(thread_position_in_grid_3d())
            stride = threads_per_grid_3d()
            while 1 <= is[1] <= size(dest, 1) &&
                  1 <= is[2] <= size(dest, 2) &&
                  1 <= is[3] <= size(dest, 3)
                I = CartesianIndex(is)
                @inbounds dest[I] = bc[I]
                is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3])
            end
            return
        end

        kernel = @metal launch=false broadcast_3d(dest, bc)
        # the launch below spans the whole of `dest`: `groups .* threads >= size(dest)`
        w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
        h = min(size(dest, 2), kernel.pipeline.threadExecutionWidth,
                kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
        d = min(size(dest, 3), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h))
        threads = (w, h, d)
        groups = cld.(size(dest), threads)
    else
        function broadcast_cartesian(dest, bc)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                I = @inbounds CartesianIndices(dest)[i]
                @inbounds dest[I] = bc[I]
                i += stride
            end
            return
        end

        kernel = @metal launch=false broadcast_cartesian(dest, bc)
        elements = cld(length(dest), 4)
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
    end
    kernel(dest, bc; threads, groups)

    return dest
end

test/metal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ end
456456
buf = Base.unsafe_convert(MTL.MTLBuffer, arr)
457457
Metal.unsafe_fill!(current_device(), Metal.MtlPointer{T}(buf, 0), T(val), 4)
458458

459-
@test all(arr .== val)
459+
@test all(Array(arr) .== val)
460460
end
461461
end
462462

0 commit comments

Comments
 (0)