Skip to content

Commit 16dc24d

Browse files
committed
Add an AbstractGPUArrayStyle that preserves dimensionality.
1 parent cf699ef commit 16dc24d

File tree

5 files changed

+32
-22
lines changed

5 files changed

+32
-22
lines changed

src/host/broadcast.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
11
# broadcasting operations
22

3+
export AbstractGPUArrayStyle
4+
35
using Base.Broadcast
46

5-
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
7+
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
68

7-
# we define a generic `BroadcastStyle` here that should be sufficient for most cases.
8-
# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
9-
# them to further change or optimize broadcasting.
10-
#
11-
# TODO: investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
12-
#
13-
# NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
14-
# instead of using `ArrayStyle{AbstractGPUArray}`, due to the fact how `similar` works.
15-
BroadcastStyle(::Type{T}) where {T<:AbstractGPUArray} = ArrayStyle{T}()
9+
"""
10+
Abstract supertype for GPU array styles. The `N` parameter is the dimensionality.
11+
12+
Downstream implementations should provide a concrete array style type that inherits from
13+
this supertype.
14+
"""
15+
abstract type AbstractGPUArrayStyle{N} <: AbstractArrayStyle{N} end
1616

1717
# Wrapper types otherwise forget that they are GPU compatible
18-
#
19-
# NOTE: Don't directly use ArrayStyle{AbstractGPUArray} here since that would mean that `CuArrays`
20-
# customization no longer take effect.
18+
# NOTE: don't directly use GPUArrayStyle here not to lose downstream customizations.
2119
for (W, ctor) in Adapt.wrappers
2220
@eval begin
2321
BroadcastStyle(::Type{<:$W}) where {AT<:AbstractGPUArray} = BroadcastStyle(AT)

src/host/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
5757
gpu_promote_type(::typeof(abs), ::Type{Complex{T}}) where {T} = T
5858
gpu_promote_type(::typeof(abs2), ::Type{Complex{T}}) where {T} = T
5959

60-
import Base.Broadcast: Broadcasted, ArrayStyle
61-
const GPUSrcArray = Union{Broadcasted{ArrayStyle{AT}}, AbstractGPUArray{T, N}} where {T, N, AT<:AbstractGPUArray}
60+
import Base.Broadcast: Broadcasted
61+
const GPUSrcArray = Union{Broadcasted{<:AbstractGPUArrayStyle}, <:AbstractGPUArray}
6262

6363
function Base.mapreduce(f::Function, op::Function, A::GPUSrcArray; dims = :, init...)
6464
mapreduce_impl(f, op, init.data, A, dims)

src/host/quirks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ if VERSION >= v"1.3.0-alpha.107"
2222
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
2323
combine_axes(A) = axes(A)
2424

25-
Broadcast._axes(::Broadcasted{ArrayStyle{AT}}, axes::Tuple) where {AT <: AbstractGPUArray} = axes
26-
@inline Broadcast._axes(bc::Broadcasted{ArrayStyle{AT}}, ::Nothing) where {AT <: AbstractGPUArray} = combine_axes(bc.args...)
25+
Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes
26+
@inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...)
2727
end

src/reference.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,19 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x
193193

194194
## broadcast
195195

196-
using Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
196+
using Base.Broadcast: BroadcastStyle, Broadcasted
197197

198-
BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()
198+
struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
199+
JLArrayStyle(::Val{N}) where N = JLArrayStyle{N}()
200+
JLArrayStyle{M}(::Val{N}) where {N,M} = JLArrayStyle{N}()
199201

200-
function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
202+
BroadcastStyle(::Type{JLArray{T,N}}) where {T,N} = JLArrayStyle{N}()
203+
204+
Base.similar(bc::Broadcasted{JLArrayStyle{N}}, ::Type{T}) where {N,T} =
201205
similar(JLArray{T}, axes(bc))
202-
end
203206

204-
Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}, dims...) where {T} = JLArray{T}(undef, dims...)
207+
Base.similar(bc::Broadcasted{JLArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
208+
JLArray{T}(undef, dims...)
205209

206210

207211
## memory operations

test/testsuite/broadcasting.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ function broadcasting(AT)
118118
@test compare((A, B) -> A .* B .+ ET(10), AT, rand(ET, 40, 40), rand(ET, 40, 40))
119119
end
120120
end
121+
122+
@testset "0D" begin
123+
x = AT{Float64}(undef)
124+
x .= 1
125+
@test collect(x)[] == 1
126+
x /= 2
127+
@test collect(x)[] == 0.5
128+
end
121129
end
122130

123131
function vec3(AT)

0 commit comments

Comments
 (0)