Skip to content

Commit d81fd4c

Browse files
Handle broadcasting when storage types are different (#605)
Co-authored-by: Christian Guinard <[email protected]>
1 parent 3e9bbac commit d81fd4c

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

src/broadcast.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@ BroadcastStyle(W::Type{<:WrappedMtlArray{T,N}}) where {T,N} =
1212

1313
# when the operands use different storage modes, we cannot know
1414
# which one is better, so fall back to shared storage
15-
BroadcastStyle(::MtlArrayStyle{N, S1},
16-
::MtlArrayStyle{N, S2}) where {N,S1,S2} =
17-
MtlArrayStyle{N, SharedStorage}()
15+
BroadcastStyle(::MtlArrayStyle{N1, S1},
16+
::MtlArrayStyle{N2, S2}) where {N1,N2,S1,S2} =
17+
MtlArrayStyle{max(N1, N2), SharedStorage}()
18+
19+
# resolve ambiguity: different N, same storage type
20+
BroadcastStyle(::MtlArrayStyle{N1, S},
21+
::MtlArrayStyle{N2, S}) where {N1,N2,S} =
22+
MtlArrayStyle{max(N1, N2), S}()
1823

1924
# allocation of output arrays
2025
Base.similar(::Broadcasted{MtlArrayStyle{N, S}}, ::Type{T}, dims) where {T, N, S} =

test/array.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,24 @@ end
557557
@test testf(x->min.(x, one(Float32)), randn(Float32, 1000))
558558
@test testf(x->min.(max.(x, zero(Float32)), one(Float32)), randn(Float32, 1000))
559559
@test testf(x->max.(min.(x, one(Float32)), zero(Float32)), randn(Float32, 1000))
560+
561+
# broadcasting preserves the storage type of its input
562+
let x = Metal.zeros(Float32, 1; storage=Metal.SharedStorage)
563+
y = x .+ 1
564+
@test is_shared(y)
565+
end
566+
567+
# when storages are different, choose shared
568+
let x = Metal.zeros(Float32, 1; storage=Metal.SharedStorage), y = Metal.zeros(Float32, 1; storage=Metal.PrivateStorage)
569+
z = x .+ y
570+
@test is_shared(z)
571+
end
572+
573+
let x = Metal.zeros(Float32, 2, 2; storage=Metal.SharedStorage), y = Metal.zeros(Float32, 2; storage=Metal.PrivateStorage)
574+
z = x .+ y
575+
@test is_shared(z)
576+
end
560577
end
561578

579+
562580
end

0 commit comments

Comments
 (0)