Skip to content

Commit ab7ac23

Browse files
authored
Merge pull request #207 from JuliaGPU/tb/axes
Remove hack and properly reimplement Broadcast._bcs1.
2 parents 5462b74 + 1d1416b commit ab7ac23

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

src/GPUArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,6 @@ include("array.jl")
3838

3939
include("testsuite.jl")
4040

41+
include("quirks.jl")
42+
4143
end # module

src/mapreduce.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,13 @@ for i = 0:10
131131
fargs = ntuple(x-> :(simple_broadcast_index($(args[x]), cartesian_global_index...)), i)
132132
@eval begin
133133
# http://developer.amd.com/resources/articles-whitepapers/opencl-optimization-case-study-simple-reductions/
134-
function reduce_kernel(state, f, op, v0::T, A, len, ax, ::Val{LMEM}, result, $(args...)) where {T, LMEM}
134+
function reduce_kernel(state, f, op, v0::T, A, ::Val{LMEM}, result, $(args...)) where {T, LMEM}
135135
tmp_local = @LocalMemory(state, T, LMEM)
136136
global_index = linear_index(state)
137137
acc = v0
138138
# # Loop sequentially over chunks of input vector
139-
# HACK: length(A) and axes(A) aren't GPU compatible, so pass them instead
140-
# https://github.com/JuliaGPU/CUDAnative.jl/issues/367
141-
@inbounds while global_index <= len
142-
cartesian_global_index = Tuple(CartesianIndices(ax)[global_index])
139+
@inbounds while global_index <= length(A)
140+
cartesian_global_index = Tuple(CartesianIndices(axes(A))[global_index])
143141
@inbounds element = f(A[cartesian_global_index...], $(fargs...))
144142
acc = op(acc, element)
145143
global_index += global_size(state)
@@ -184,7 +182,7 @@ function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest::Tuple) where {OT}
184182
end
185183
out = similar(A, OT, (blocksize,))
186184
fill!(out, v0)
187-
args = (f, op, v0, A, length(A), axes(A), Val{threads}(), out, rest...)
185+
args = (f, op, v0, A, Val{threads}(), out, rest...)
188186
gpu_call(reduce_kernel, out, args, ((blocksize,), (threads,)))
189187
reduce(op, Array(out))
190188
end

src/quirks.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# revert JuliaLang/julia#32867; avoid string interpolation
2+
#
3+
# NOTE: without contextual dispatch, we can only redefine methods where a GPU-specific
4+
# type occurs in the signature (or we'll get a "fatal precompilation failure" error)
5+
6+
if VERSION >= v"1.3.0-alpha.107"
7+
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))
8+
_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
9+
_bcs1(a, b::Integer) = _bcs1(b, a)
10+
_bcs1(a, b) = Broadcast._bcsm(b, a) ? Broadcast.axistype(b, a) : (Broadcast._bcsm(a, b) ? Broadcast.axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
11+
12+
_bcs(::Tuple{}, ::Tuple{}) = ()
13+
_bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), Base.tail(newshape))...)
14+
_bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(Base.tail(shape), ())...)
15+
function _bcs(shape::Tuple, newshape::Tuple)
16+
return (_bcs1(shape[1], newshape[1]), _bcs(Base.tail(shape), Base.tail(newshape))...)
17+
end
18+
19+
broadcast_shape(shape::Tuple) = shape
20+
broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...)
21+
22+
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
23+
combine_axes(A) = axes(A)
24+
25+
Broadcast._axes(::Broadcasted{ArrayStyle{AT}}, axes::Tuple) where {AT <: GPUArray} = axes
26+
@inline Broadcast._axes(bc::Broadcasted{ArrayStyle{AT}}, ::Nothing) where {AT <: GPUArray} = combine_axes(bc.args...)
27+
end

0 commit comments

Comments (0)