Skip to content

Commit 2ab7554

Browse files
committed
BFloat16 extension
1 parent 558ccfa commit 2ab7554

File tree

5 files changed

+22
-5
lines changed

5 files changed

+22
-5
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.0.0"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
9+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
910
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
1011
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
1112
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2829
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2930

3031
[weakdeps]
32+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
3133
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3234

3335
[extensions]
3436
SpecialFunctionsExt = "SpecialFunctions"
37+
BFloat16sExt = "BFloat16s"
3538

3639
[compat]
3740
Adapt = "4"
3841
Artifacts = "1"
42+
BFloat16s = "0.5"
3943
CEnum = "0.4, 0.5"
4044
CodecBzip2 = "0.8"
4145
ExprTools = "0.1"

ext/BFloat16sExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module BFloat16sExt
2+
3+
using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version
4+
using BFloat16s
5+
6+
# BFloat is only supported in MPS starting in MacOS 14
7+
if macos_version() >= v"14"
8+
Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16
9+
jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16
10+
end
11+
12+
end # module

lib/mps/matrix.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,C
4343
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
4444
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
4545
end
46-
# BFloat is only supported in MPS starting in MacOS 14
47-
if macos_version() >= v"14" && isdefined(Core, :BFloat16)
48-
Base.convert(::Type{MPSDataType}, ::Type{Core.BFloat16}) = MPSDataTypeBFloat16
49-
jl_mps_to_typ[MPSDataTypeBFloat16] = Core.BFloat16
50-
end
46+
5147
Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]
5248

5349

src/Metal.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,8 @@ include("MetalKernels.jl")
6767
import .MetalKernels: MetalBackend
6868
export MetalBackend
6969

70+
@static if !isdefined(Base, :get_extension)
71+
include("../ext/BFloat16sExt.jl")
72+
end
73+
7074
end # module

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
34
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
45
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
56
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

0 commit comments

Comments
 (0)