Skip to content

Commit a387890

Browse files
authored
Fix broadcasting for differently backed CLArrays of different dimensions (#345)
1 parent b725007 commit a387890

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,10 @@ BroadcastStyle(W::Type{<:WrappedCLArray{T, N}}) where {T, N} =
 # when we are dealing with different buffer styles, we cannot know
 # which one is better, so use shared memory
 BroadcastStyle(
-    ::CLArrayStyle{N, B1},
+    ::CLArrayStyle{M, B1},
     ::CLArrayStyle{N, B2}
-) where {N, B1, B2} =
-    CLArrayStyle{N, cl.UnifiedSharedMemory}()
+) where {M, N, B1, B2} =
+    CLArrayStyle{max(M, N), B1 == B2 ? B1 : cl.UnifiedSharedMemory}()
 
 # allocation of output arrays
 Base.similar(bc::Broadcasted{CLArrayStyle{N, B}}, ::Type{T}, dims) where {T, N, B} =

test/execution.jl

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,4 +115,22 @@ a = CLArray{Int}(undef, 10)
 
 end
 
+@testset "broadcasting" begin
+    a = rand(Float32, 2, 3)
+    b = rand(Float32, 2)
+
+    c = a .+ b
+    a_cl, b_cl = CLArray(a), CLArray(b)
+    c_cl = a_cl .+ b_cl
+    @test Array(c_cl) == c
+    @test c_cl isa CLArray{Float32, 2, OpenCL.memory_type()}
+
+    if cl.usm_supported(cl.device())
+        a_cl, b_cl = CLMatrix{Float32, cl.UnifiedSharedMemory}(a), CLVector{Float32, OpenCL.memory_type()}(b)
+        c_cl = a_cl .+ b_cl
+        @test Array(c_cl) == c
+        @test c_cl isa CLArray{Float32, 2, cl.UnifiedSharedMemory}
+    end
+end
+
 end

0 commit comments

Comments
 (0)