Skip to content

Commit 91fbbac

Browse files
authored
Merge pull request #245 from JuliaGPU/tb/map
Simplify map dispatch.
2 parents 02b3fb8 + b47b8c7 commit 91fbbac

File tree

6 files changed

+33
-56
lines changed

6 files changed

+33
-56
lines changed

.gitlab-ci.yml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,6 @@ include:
44

55
# Julia versions
66

7-
julia:1.0:
8-
extends:
9-
- .julia:1.0
10-
- .test
11-
12-
julia:1.1:
13-
extends:
14-
- .julia:1.1
15-
- .test
16-
17-
julia:1.2:
18-
extends:
19-
- .julia:1.2
20-
- .test
21-
227
julia:1.3:
238
extends:
249
- .julia:1.3

.travis.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ os:
88
dist: trusty
99

1010
julia:
11-
- 1.0
12-
- 1.1
13-
- 1.2
1411
- 1.3
1512
- nightly
1613

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1313
[compat]
1414
AbstractFFTs = "0.4, 0.5"
1515
Adapt = "0.4.1, 1.0"
16-
julia = "1.0"
16+
julia = "1.3"
1717

1818
[extras]
1919
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

src/host/base.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,5 @@
11
# common Base functionality
22

3-
allequal(x) = true
4-
allequal(x, y, z...) = x == y && allequal(y, z...)
5-
function Base.map!(f, y::AbstractGPUArray, xs::AbstractGPUArray...)
6-
@assert allequal(size.((y, xs...))...)
7-
return y .= f.(xs...)
8-
end
9-
function Base.map(f, y::AbstractGPUArray, xs::AbstractGPUArray...)
10-
@assert allequal(size.((y, xs...))...)
11-
return f.(y, xs...)
12-
end
13-
14-
# Break ambiguities with base
15-
Base.map!(f, y::AbstractGPUArray) =
16-
invoke(map!, Tuple{Any,AbstractGPUArray,Vararg{AbstractGPUArray}}, f, y)
17-
Base.map!(f, y::AbstractGPUArray, x::AbstractGPUArray) =
18-
invoke(map!, Tuple{Any,AbstractGPUArray, Vararg{AbstractGPUArray}}, f, y, x)
19-
Base.map!(f, y::AbstractGPUArray, x1::AbstractGPUArray, x2::AbstractGPUArray) =
20-
invoke(map!, Tuple{Any,AbstractGPUArray, Vararg{AbstractGPUArray}}, f, y, x1, x2)
21-
223
function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
234
o, p = size(a, 1), size(a, 2)
245
b = similar(a, o*m, p*n)

src/host/broadcast.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ end
8282
# `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
8383
@inline Base.copyto!(dest::GPUDestArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
8484
copyto!(dest, convert(Broadcasted{Nothing}, bc))
85+
86+
87+
## map
88+
89+
allequal(x) = true
90+
allequal(x, y, z...) = x == y && allequal(y, z...)
91+
92+
function Base.map!(f, y::AbstractGPUArray, xs::AbstractArray...)
93+
@assert allequal(size.((y, xs...))...)
94+
return y .= f.(xs...)
95+
end
96+
97+
function Base.map(f, y::AbstractGPUArray, xs::AbstractArray...)
98+
@assert allequal(size.((y, xs...))...)
99+
return f.(y, xs...)
100+
end

src/host/quirks.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,23 @@
33
# NOTE: without contextual dispatch, we can only redefine methods where a GPU-specific
44
# type occurs in the signature (or we'll get a "fatal precompilation failure" error)
55

6-
if VERSION >= v"1.3.0-alpha.107"
7-
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))
8-
_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
9-
_bcs1(a, b::Integer) = _bcs1(b, a)
10-
_bcs1(a, b) = Broadcast._bcsm(b, a) ? Broadcast.axistype(b, a) : (Broadcast._bcsm(a, b) ? Broadcast.axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
6+
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))
7+
_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
8+
_bcs1(a, b::Integer) = _bcs1(b, a)
9+
_bcs1(a, b) = Broadcast._bcsm(b, a) ? Broadcast.axistype(b, a) : (Broadcast._bcsm(a, b) ? Broadcast.axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
1110

12-
_bcs(::Tuple{}, ::Tuple{}) = ()
13-
_bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), Base.tail(newshape))...)
14-
_bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(Base.tail(shape), ())...)
15-
function _bcs(shape::Tuple, newshape::Tuple)
16-
return (_bcs1(shape[1], newshape[1]), _bcs(Base.tail(shape), Base.tail(newshape))...)
17-
end
11+
_bcs(::Tuple{}, ::Tuple{}) = ()
12+
_bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), Base.tail(newshape))...)
13+
_bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(Base.tail(shape), ())...)
14+
function _bcs(shape::Tuple, newshape::Tuple)
15+
return (_bcs1(shape[1], newshape[1]), _bcs(Base.tail(shape), Base.tail(newshape))...)
16+
end
1817

19-
broadcast_shape(shape::Tuple) = shape
20-
broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...)
18+
broadcast_shape(shape::Tuple) = shape
19+
broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...)
2120

22-
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
23-
combine_axes(A) = axes(A)
21+
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
22+
combine_axes(A) = axes(A)
2423

25-
Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes
26-
@inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...)
27-
end
24+
Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes
25+
@inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...)

0 commit comments

Comments
 (0)