Skip to content

Commit 356ee6c

Browse files
adapt_storage-related improvements (#296)
* Check for Int128 compatibility * Remove unneeded adapt_storage definitions and add tests. * Add forgotten adaptor from parameterizing storage mode and test
1 parent 81f716c commit 356ee6c

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

src/array.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function check_eltype(T)
3030
Base.allocatedinline(T) || error("MtlArray only supports element types that are stored inline")
3131
Base.isbitsunion(T) && error("MtlArray does not yet support isbits-union arrays")
3232
contains_eltype(T, Float64) && error("Metal does not support Float64 values, try using Float32 instead")
33+
contains_eltype(T, Int128) && error("Metal does not support Int128 values, try using Int64 instead")
34+
contains_eltype(T, UInt128) && error("Metal does not support UInt128 values, try using UInt64 instead")
3335
end
3436

3537
"""
@@ -314,6 +316,8 @@ Adapt.adapt_storage(::Type{<:MtlArray{T}}, xs::AT) where {T, AT<:AbstractArray}
314316
isbitstype(AT) ? xs : convert(MtlArray{T}, xs)
315317
Adapt.adapt_storage(::Type{<:MtlArray{T, N}}, xs::AT) where {T, N, AT<:AbstractArray} =
316318
isbitstype(AT) ? xs : convert(MtlArray{T,N}, xs)
319+
Adapt.adapt_storage(::Type{<:MtlArray{T, N, S}}, xs::AT) where {T, N, S, AT<:AbstractArray} =
320+
isbitstype(AT) ? xs : convert(MtlArray{T,N,S}, xs)
317321

318322

319323
## opinionated gpu array adaptor
@@ -325,19 +329,12 @@ struct MtlArrayAdaptor{S} end
325329
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T,N,S} =
326330
isbits(xs) ? xs : MtlArray{T,N,S}(xs)
327331

328-
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:AbstractFloat,N,S} =
332+
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float64,N,S} =
329333
isbits(xs) ? xs : MtlArray{Float32,N,S}(xs)
330334

331-
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:AbstractFloat},N,S} =
335+
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{<:Float64},N,S} =
332336
isbits(xs) ? xs : MtlArray{ComplexF32,N,S}(xs)
333337

334-
# not for Float16
335-
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Float16,N,S} =
336-
isbits(xs) ? xs : MtlArray{T,N,S}(xs)
337-
338-
Adapt.adapt_storage(::MtlArrayAdaptor{S}, xs::AbstractArray{T,N}) where {T<:Complex{Float16},N,S} =
339-
isbits(xs) ? xs : MtlArray{T,N,S}(xs)
340-
341338
"""
342339
mtl(A; storage=Private)
343340

test/array.jl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,40 @@ end
2222
@test Base.elsize(xs) == sizeof(Int)
2323
@test pointer(MtlArray{Int, 2}(xs)) != pointer(xs)
2424

25-
# test aggressive conversion to Float32, but only for floats, and only with `mtl`
26-
@test mtl([1]) isa MtlArray{Int}
27-
@test mtl(Float64[1]) isa MtlArray{Float32}
28-
@test mtl(ComplexF64[1+1im]) isa MtlArray{ComplexF32}
29-
@test mtl(ComplexF16[1+1im]) isa MtlArray{ComplexF16}
25+
# Page 22 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
26+
# Only bfloat missing
27+
supported_number_types = [Float16 => Float16,
28+
Float32 => Float32,
29+
Float64 => Float32,
30+
Bool => Bool,
31+
Int16 => Int16,
32+
Int32 => Int32,
33+
Int64 => Int64,
34+
Int8 => Int8,
35+
UInt16 => UInt16,
36+
UInt32 => UInt32,
37+
UInt64 => UInt64,
38+
UInt8 => UInt8]
39+
# Test supported types and ensure only Float64 get converted to Float32
40+
for (SrcType, TargType) in supported_number_types
41+
@test mtl(SrcType[1]) isa MtlArray{TargType}
42+
@test mtl(Complex{SrcType}[1+1im]) isa MtlArray{Complex{TargType}}
43+
end
44+
45+
# test the regular adaptor
46+
@test Adapt.adapt(MtlArray, [1 2;3 4]) isa MtlArray{Int, 2, Private}
47+
@test Adapt.adapt(MtlArray{Float32}, [1 2;3 4]) isa MtlArray{Float32, 2, Private}
48+
@test Adapt.adapt(MtlArray{Float32, 2}, [1 2;3 4]) isa MtlArray{Float32, 2, Private}
49+
@test Adapt.adapt(MtlArray{Float32, 2, Shared}, [1 2;3 4]) isa MtlArray{Float32, 2, Shared}
50+
@test Adapt.adapt(MtlMatrix{ComplexF32, Shared}, [1 2;3 4]) isa MtlArray{ComplexF32, 2, Shared}
3051
@test Adapt.adapt(MtlArray{Float16}, Float64[1]) isa MtlArray{Float16}
3152

53+
# Test a few explicitly unsupported types
54+
@test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigInt[1])
55+
@test_throws "MtlArray only supports element types that are stored inline" MtlArray(BigFloat[1])
56+
@test_throws "Metal does not support Float64 values" MtlArray(Float64[1])
57+
@test_throws "Metal does not support Int128 values" MtlArray(Int128[1])
58+
@test_throws "Metal does not support UInt128 values" MtlArray(UInt128[1])
3259

3360
@test collect(Metal.zeros(2, 2)) == zeros(Float32, 2, 2)
3461
@test collect(Metal.ones(2, 2)) == ones(Float32, 2, 2)

0 commit comments

Comments
 (0)