@@ -215,7 +215,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
215215 end
216216
217217 XLA. await (a. data)
218- if XLA . BufferOnCPU (a . data . buffer )
218+ if buffer_on_cpu (a )
219219 buf = a. data. buffer
220220 GC. @preserve buf begin
221221 ptr = Base. unsafe_convert (Ptr{T}, XLA. UnsafeBufferPointer (buf))
@@ -246,7 +246,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
246246 end
247247
248248 XLA. await (a. data)
249- if XLA . BufferOnCPU (a . data . buffer )
249+ if buffer_on_cpu (a )
250250 buf = a. data. buffer
251251 GC. @preserve buf begin
252252 ptr = Base. unsafe_convert (Ptr{T}, XLA. UnsafeBufferPointer (buf))
@@ -289,15 +289,52 @@ end
289289
290290# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
291291function Base. copy (bc:: Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}} )
292- ElType = Base. Broadcast. combine_eltypes (bc. f, bc. args)
293- if ! Base. isconcretetype (ElType)
294- throw (
295- ErrorException (
296- " `copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
297- ),
298- )
292+ for x in bc. args
293+ x isa ConcreteRArray && XLA. await (x. data)
299294 end
300295
301- aux = copyto! (similar (Array{ElType}, axes (bc)), bc)
302- return ConcreteRArray (aux)
296+ all_on_cpu = all (buffer_on_cpu, bc. args)
297+ if all_on_cpu
298+ ElType = Base. Broadcast. combine_eltypes (bc. f, bc. args)
299+ if ! Base. isconcretetype (ElType)
300+ throw (
301+ ErrorException (
302+ " `copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
303+ ),
304+ )
305+ end
306+ aux = copyto! (similar (Array{ElType}, axes (bc)), bc)
307+ return ConcreteRArray (aux)
308+ end
309+
310+ fn = Reactant. compile (Broadcast. BroadcastFunction (bc. f), (bc. args... ,))
311+ return fn (bc. args... )
312+ end
313+
314+ function Base. copyto! (dest:: ConcreteRArray , src:: ConcreteRArray )
315+ dest. data = src. data
316+ return dest
317+ end
318+
319+ function Base. mapreduce (
320+ @nospecialize (f),
321+ @nospecialize (op),
322+ @nospecialize (A:: ConcreteRArray{T,N} );
323+ dims= :,
324+ init= nothing ,
325+ ) where {T,N}
326+ fn = Reactant. compile (CallMapReduce (f, op, dims, init), (A,))
327+ return fn (A)
328+ end
329+
330+ struct CallMapReduce{Fn,Op,Dims,Init}
331+ f:: Fn
332+ op:: Op
333+ dims:: Dims
334+ init:: Init
303335end
336+
337+ (f:: CallMapReduce )(A) = Base. mapreduce (f. f, f. op, A; f. dims, f. init)
338+
339+ buffer_on_cpu (:: Any ) = true
340+ buffer_on_cpu (x:: ConcreteRArray ) = XLA. BufferOnCPU (x. data. buffer)
0 commit comments