@@ -19,3 +19,139 @@ BroadcastStyle(::MtlArrayStyle{N, S1},
# allocation of output arrays
# Allocate the output container for a Metal broadcast: dimensionality comes
# from `dims`, while the storage mode `S` is carried over from the broadcast
# style so the result lives in the same kind of Metal buffer as the inputs.
function Base.similar(bc::Broadcasted{MtlArrayStyle{N,S}}, ::Type{T}, dims) where {T,N,S}
    return similar(MtlArray{T,length(dims),S}, dims)
end
22+
# A compile-time variant of CartesianIndices: the index ranges are stored in
# the type parameter `I` rather than in a field, so kernels specialized on a
# particular shape can avoid runtime integer division (at the expense of
# compiling one kernel per shape).
struct StaticCartesianIndices{N, I} end

function StaticCartesianIndices(ci::CartesianIndices{N}) where {N}
    return StaticCartesianIndices{N, ci.indices}()
end
StaticCartesianIndices(x) = StaticCartesianIndices(CartesianIndices(x))

# Reconstruct the ordinary, runtime CartesianIndices from the type parameter.
function Base.CartesianIndices(sci::StaticCartesianIndices{N, I}) where {N, I}
    return CartesianIndices{N, typeof(I)}(I)
end

# Linear-to-Cartesian lookup, forwarding to the reconstructed iterator.
Base.@propagate_inbounds function Base.getindex(sci::StaticCartesianIndices, i::Integer)
    return CartesianIndices(sci)[Int(i)]
end

Base.length(sci::StaticCartesianIndices) = length(CartesianIndices(sci))

function Base.show(io::IO, sci::StaticCartesianIndices)
    print(io, "Static")
    show(io, CartesianIndices(sci))
end
38+
# specialization of the broadcast implementation to avoid expensive integer divisions
#
# Per-shape invocation counts, used to decide when a broadcast shape is common
# enough to be worth compiling a shape-specialized kernel for.
# Typed values (Int) avoid `Any`-typed lookups on the broadcast hot path;
# keys stay `Any` since many distinct CartesianIndices types are stored.
# NOTE(review): this global cache grows without bound and is not lock-protected —
# presumably only touched from a single task; verify against callers.
const _broadcast_shapes = Base.IdDict{Any,Int}()
# number of times a shape must be seen before kernels get specialized on it
const BROADCAST_SPECIALIZATION_THRESHOLD = 10
# Intercept in-place materialization so Metal drives its own copy kernel.
# Mirroring the generic Base implementation, the Broadcasted is re-wrapped with
# the destination's axes before being handed to our `_copyto!`.
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:MtlArrayStyle}
    instantiated = Broadcast.instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))
    return _copyto!(dest, instantiated)
end
# Route broadcast copies through `_copyto!`. The `Broadcasted{Nothing}` method
# is kept so ArrayConflict-style broadcasts still reach the Metal path.
@inline function Base.copyto!(dest::MtlArray, bc::Broadcasted{Nothing})
    return _copyto!(dest, bc)
end
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:MtlArrayStyle})
    return _copyto!(dest, bc)
end
# Shared implementation behind `materialize!`/`copyto!`: checks shapes, then
# launches a Metal kernel matched to the destination's dimensionality. Shapes
# seen more than BROADCAST_SPECIALIZATION_THRESHOLD times get a kernel that is
# specialized on the exact index ranges, eliminating runtime integer division.
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc = Broadcast.preprocess(dest, bc)

    # if this is a common broadcast shape, specialize the kernel on it.
    # (single `get` + store instead of the previous haskey/getindex/setindex!
    # sequence, which looked the key up three times per broadcast)
    Is = CartesianIndices(dest)
    count = get(_broadcast_shapes, Is, 0) + 1
    _broadcast_shapes[Is] = count
    if count > BROADCAST_SPECIALIZATION_THRESHOLD
        # strided 1D kernel over indices whose ranges are encoded in the type
        # of `Is`, so converting linear -> Cartesian needs no runtime division
        function broadcast_cartesian_static(dest, bc, Is)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                I = @inbounds Is[i]
                @inbounds dest[I] = bc[I]
                i += stride
            end
            return
        end

        Is = StaticCartesianIndices(Is)
        kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
        elements = cld(length(dest), 4)     # aim for ~4 elements per thread
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
        kernel(dest, bc, Is; threads, groups)
        return dest
    end

    # try to use the most appropriate hardware index to avoid integer division
    if ndims(dest) == 1 ||
       (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
        # both sides index linearly: a plain strided 1D loop suffices
        function broadcast_linear(dest, bc)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                @inbounds dest[i] = bc[i]
                i += stride
            end
            return
        end

        kernel = @metal launch=false broadcast_linear(dest, bc)
        elements = cld(length(dest), 4)
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
    elseif ndims(dest) == 2
        # use the hardware's 2D thread position directly as a Cartesian index
        function broadcast_2d(dest, bc)
            is = Tuple(thread_position_in_grid_2d())
            stride = threads_per_grid_2d()
            while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2)
                I = CartesianIndex(is)
                @inbounds dest[I] = bc[I]
                is = (is[1] + stride[1], is[2] + stride[2])
            end
            return
        end

        kernel = @metal launch=false broadcast_2d(dest, bc)
        # width first (capped at the execution width), height from what remains
        w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
        h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
        threads = (w, h)
        groups = cld.(size(dest), threads)
    elseif ndims(dest) == 3
        # use the hardware's 3D thread position directly as a Cartesian index
        function broadcast_3d(dest, bc)
            is = Tuple(thread_position_in_grid_3d())
            stride = threads_per_grid_3d()
            while 1 <= is[1] <= size(dest, 1) &&
                  1 <= is[2] <= size(dest, 2) &&
                  1 <= is[3] <= size(dest, 3)
                I = CartesianIndex(is)
                @inbounds dest[I] = bc[I]
                is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3])
            end
            return
        end

        kernel = @metal launch=false broadcast_3d(dest, bc)
        w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
        h = min(size(dest, 2), kernel.pipeline.threadExecutionWidth,
                kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
        d = min(size(dest, 3), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h))
        threads = (w, h, d)
        groups = cld.(size(dest), threads)
    else
        # fallback: linear hardware index, Cartesian conversion at run time
        # (this is the division-heavy path the specializations above avoid)
        function broadcast_cartesian(dest, bc)
            i = thread_position_in_grid_1d()
            stride = threads_per_grid_1d()
            while 1 <= i <= length(dest)
                I = @inbounds CartesianIndices(dest)[i]
                @inbounds dest[I] = bc[I]
                i += stride
            end
            return
        end

        kernel = @metal launch=false broadcast_cartesian(dest, bc)
        elements = cld(length(dest), 4)
        threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
        groups = cld(elements, threads)
    end
    kernel(dest, bc; threads, groups)

    return dest
end