@@ -4,15 +4,6 @@ using Base.Broadcast
4
4
5
5
import Base. Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
6
6
7
- const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
8
- Base. RefValue{<: AbstractGPUArray{T} }}
9
-
10
- # Ref is special: it's not a real wrapper, so not part of Adapt,
11
- # but it is commonly used to bypass broadcasting of an argument
12
- # so we need to preserve its dimensionless properties.
13
- BroadcastStyle (:: Type{Base.RefValue{AT}} ) where {AT<: AbstractGPUArray } =
14
- typeof (BroadcastStyle (AT))(Val (0 ))
15
- backend (:: Type{Base.RefValue{AT}} ) where {AT<: AbstractGPUArray } = backend (AT)
16
7
# but make sure we don't dispatch to the optimized copy method that directly indexes
17
8
function Broadcast. copy (bc:: Broadcasted{<:AbstractGPUArrayStyle{0}} )
18
9
ElType = Broadcast. combine_eltypes (bc. f, bc. args)
41
32
return _copyto! (dest, instantiate (Broadcasted {Style} (bc. f, bc. args, axes (dest))))
42
33
end
43
34
44
- @inline Base. copyto! (dest:: BroadcastGPUArray , bc:: Broadcasted{Nothing} ) = _copyto! (dest, bc) # Keep it for ArrayConflict
35
+ @inline Base. copyto! (dest:: AnyGPUArray , bc:: Broadcasted{Nothing} ) = _copyto! (dest, bc) # Keep it for ArrayConflict
45
36
46
37
@inline Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:AbstractGPUArrayStyle} ) = _copyto! (dest, bc)
47
38
77
68
allequal (x) = true
78
69
allequal (x, y, z... ) = x == y && allequal (y, z... )
79
70
80
- function Base. map (f, x:: BroadcastGPUArray , xs:: AbstractArray... )
71
+ function Base. map (f, x:: AnyGPUArray , xs:: AbstractArray... )
81
72
# if argument sizes match, their shape needs to be preserved
82
73
xs = (x, xs... )
83
74
if allequal (size .(xs)... )
@@ -96,7 +87,7 @@ function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
96
87
return map! (f, dest, xs... )
97
88
end
98
89
99
- function Base. map! (f, dest:: BroadcastGPUArray , xs:: AbstractArray... )
90
+ function Base. map! (f, dest:: AnyGPUArray , xs:: AbstractArray... )
100
91
# custom broadcast, ignoring the container size mismatches
101
92
# (avoids the reshape + view that our mapreduce impl has to do)
102
93
indices = LinearIndices .((dest, xs... ))
0 commit comments