Skip to content

Commit dc49c29

Browse files
mtfishmanmaleadt
andauthored
Preserve 0-dim arrays in map (#599)
Co-authored-by: Tim Besard <[email protected]>
1 parent 0093468 commit dc49c29

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUArrays"
22
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3-
version = "11.2.2"
3+
version = "11.2.3"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/host/broadcast.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@ end
8282
allequal(x) = true
8383
allequal(x, y, z...) = x == y && allequal(y, z...)
8484

85-
function Base.map(f, xs::AnyGPUArray...)
85+
function Base.map(f, x1::AnyGPUArray, xrest::AnyGPUArray...)
86+
xs = (x1, xrest...)
8687
# if argument sizes match, their shape needs to be preserved
8788
if allequal(size.(xs)...)
88-
return f.(xs...)
89+
return Broadcast.broadcast_preserving_zero_d(f, xs...)
8990
end
9091

9192
# if not, treat them as iterators

test/testsuite/broadcasting.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ function broadcasting(AT, eltypes)
158158
@test compare(AT, rand(ET, 2,2), rand(ET, 2)) do x,y
159159
map(+, x, y)
160160
end
161+
############
162+
# issue #598
163+
@test compare(AT, rand(ET, ()), rand(ET, ())) do x, y
164+
map(+, x, y)
165+
end
161166
end
162167
end
163168

0 commit comments

Comments
 (0)