|
3 | 3 | # NOTE: without contextual dispatch, we can only redefine methods where a GPU-specific |
4 | 4 | # type occurs in the signature (or we'll get a "fatal precompilation failure" error) |
5 | 5 |
|
6 | | -if VERSION >= v"1.3.0-alpha.107" |
7 | | - _bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))) |
8 | | - _bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) |
9 | | - _bcs1(a, b::Integer) = _bcs1(b, a) |
10 | | - _bcs1(a, b) = Broadcast._bcsm(b, a) ? Broadcast.axistype(b, a) : (Broadcast._bcsm(a, b) ? Broadcast.axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) |
| 6 | +_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))) |
| 7 | +_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) |
| 8 | +_bcs1(a, b::Integer) = _bcs1(b, a) |
| 9 | +_bcs1(a, b) = Broadcast._bcsm(b, a) ? Broadcast.axistype(b, a) : (Broadcast._bcsm(a, b) ? Broadcast.axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) |
11 | 10 |
|
12 | | - _bcs(::Tuple{}, ::Tuple{}) = () |
13 | | - _bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), Base.tail(newshape))...) |
14 | | - _bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(Base.tail(shape), ())...) |
15 | | - function _bcs(shape::Tuple, newshape::Tuple) |
16 | | - return (_bcs1(shape[1], newshape[1]), _bcs(Base.tail(shape), Base.tail(newshape))...) |
17 | | - end |
| 11 | +_bcs(::Tuple{}, ::Tuple{}) = () |
| 12 | +_bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), Base.tail(newshape))...) |
| 13 | +_bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(Base.tail(shape), ())...) |
| 14 | +function _bcs(shape::Tuple, newshape::Tuple) |
| 15 | + return (_bcs1(shape[1], newshape[1]), _bcs(Base.tail(shape), Base.tail(newshape))...) |
| 16 | +end |
18 | 17 |
|
19 | | - broadcast_shape(shape::Tuple) = shape |
20 | | - broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...) |
| 18 | +broadcast_shape(shape::Tuple) = shape |
| 19 | +broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...) |
21 | 20 |
|
22 | | - @inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...)) |
23 | | - combine_axes(A) = axes(A) |
| 21 | +@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...)) |
| 22 | +combine_axes(A) = axes(A) |
24 | 23 |
|
25 | | - Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes |
26 | | - @inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...) |
27 | | -end |
| 24 | +Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes |
| 25 | +@inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...) |
0 commit comments