diff --git a/Project.toml b/Project.toml index fe37ad4f4..6443a80b2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.5.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -23,11 +24,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] -BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] -BFloat16sExt = "BFloat16s" SpecialFunctionsExt = "SpecialFunctions" [compat] diff --git a/ext/BFloat16sExt.jl b/ext/BFloat16sExt.jl deleted file mode 100644 index 6f11b18fc..000000000 --- a/ext/BFloat16sExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -module BFloat16sExt - -using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version -using BFloat16s - -# BFloat is only supported in MPS starting in MacOS 14 -@static if Sys.isapple() - if macos_version() >= v"14" - Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16 - jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16 - end -end - -end # module diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index eb4312863..e357c43c2 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -16,6 +16,8 @@ using ObjectiveC, .Foundation import GPUArrays +using BFloat16s + const MtlFloat = Union{Float32, Float16} const MPSShape = NSArray#{NSNumber} diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index c601ff30a..f1b156cd2 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -5,8 +5,12 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) # Conversions for MPSDataTypes with Julia equivalents const jl_mps_to_typ = Dict{MPSDataType, DataType}() -for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,(ComplexF16,:MPSDataTypeComplexFloat16),(ComplexF32,:MPSDataTypeComplexFloat32),Bool] - jltype, mpstype = if type isa Type +for type in [ + :Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64, + :Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16), + (:ComplexF32, :MPSDataTypeComplexFloat32), + ] + jltype, mpstype = if type isa Symbol type, Symbol(:MPSDataType, type) else type