Skip to content

Commit a9fb2d1

Browse files
committed
Simplify map dispatch.
1 parent 02b3fb8 commit a9fb2d1

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

src/host/base.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,5 @@
11
# common Base functionality
22

3-
allequal(x) = true
4-
allequal(x, y, z...) = x == y && allequal(y, z...)
5-
function Base.map!(f, y::AbstractGPUArray, xs::AbstractGPUArray...)
6-
@assert allequal(size.((y, xs...))...)
7-
return y .= f.(xs...)
8-
end
9-
function Base.map(f, y::AbstractGPUArray, xs::AbstractGPUArray...)
10-
@assert allequal(size.((y, xs...))...)
11-
return f.(y, xs...)
12-
end
13-
14-
# Break ambiguities with base
15-
Base.map!(f, y::AbstractGPUArray) =
16-
invoke(map!, Tuple{Any,AbstractGPUArray,Vararg{AbstractGPUArray}}, f, y)
17-
Base.map!(f, y::AbstractGPUArray, x::AbstractGPUArray) =
18-
invoke(map!, Tuple{Any,AbstractGPUArray, Vararg{AbstractGPUArray}}, f, y, x)
19-
Base.map!(f, y::AbstractGPUArray, x1::AbstractGPUArray, x2::AbstractGPUArray) =
20-
invoke(map!, Tuple{Any,AbstractGPUArray, Vararg{AbstractGPUArray}}, f, y, x1, x2)
21-
223
function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
234
o, p = size(a, 1), size(a, 2)
245
b = similar(a, o*m, p*n)

src/host/broadcast.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ end
8282
# `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
8383
@inline Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
8484
copyto!(dest, convert(Broadcasted{Nothing}, bc))
85+
86+
87+
## map
88+
89+
allequal(x) = true
90+
allequal(x, y, z...) = x == y && allequal(y, z...)
91+
92+
function Base.map!(f, y::AbstractGPUArray, xs::AbstractArray...)
93+
@assert allequal(size.((y, xs...))...)
94+
return y .= f.(xs...)
95+
end
96+
97+
function Base.map(f, y::AbstractGPUArray, xs::AbstractArray...)
98+
@assert allequal(size.((y, xs...))...)
99+
return f.(y, xs...)
100+
end

0 commit comments

Comments
 (0)