@@ -30,6 +30,8 @@ function check_eltype(T)
3030 Base. allocatedinline (T) || error (" MtlArray only supports element types that are stored inline" )
3131 Base. isbitsunion (T) && error (" MtlArray does not yet support isbits-union arrays" )
3232 contains_eltype (T, Float64) && error (" Metal does not support Float64 values, try using Float32 instead" )
33+ contains_eltype (T, Int128) && error (" Metal does not support Int128 values, try using Int64 instead" )
34+ contains_eltype (T, UInt128) && error (" Metal does not support UInt128 values, try using UInt64 instead" )
3335end
3436
3537"""
@@ -314,6 +316,8 @@ Adapt.adapt_storage(::Type{<:MtlArray{T}}, xs::AT) where {T, AT<:AbstractArray}
314316 isbitstype (AT) ? xs : convert (MtlArray{T}, xs)
315317Adapt. adapt_storage (:: Type{<:MtlArray{T, N}} , xs:: AT ) where {T, N, AT<: AbstractArray } =
316318 isbitstype (AT) ? xs : convert (MtlArray{T,N}, xs)
319+ Adapt. adapt_storage (:: Type{<:MtlArray{T, N, S}} , xs:: AT ) where {T, N, S, AT<: AbstractArray } =
320+ isbitstype (AT) ? xs : convert (MtlArray{T,N,S}, xs)
317321
318322
319323# # opinionated gpu array adaptor
@@ -325,19 +329,12 @@ struct MtlArrayAdaptor{S} end
325329Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T,N,S} =
326330 isbits (xs) ? xs : MtlArray {T,N,S} (xs)
327331
328- Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: AbstractFloat ,N,S} =
332+ Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: Float64 ,N,S} =
329333 isbits (xs) ? xs : MtlArray {Float32,N,S} (xs)
330334
331- Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: Complex{<:AbstractFloat } ,N,S} =
335+ Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: Complex{<:Float64 } ,N,S} =
332336 isbits (xs) ? xs : MtlArray {ComplexF32,N,S} (xs)
333337
334- # not for Float16
335- Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: Float16 ,N,S} =
336- isbits (xs) ? xs : MtlArray {T,N,S} (xs)
337-
338- Adapt. adapt_storage (:: MtlArrayAdaptor{S} , xs:: AbstractArray{T,N} ) where {T<: Complex{Float16} ,N,S} =
339- isbits (xs) ? xs : MtlArray {T,N,S} (xs)
340-
341338"""
342339 mtl(A; storage=Private)
343340
0 commit comments