Skip to content

Commit e057d13

Browse files
authored
Merge pull request #326 from christiangnrd/improvements
BFloat16s.jl extension and related improvements
2 parents bc2131e + 0fd9eaa commit e057d13

File tree

8 files changed

+63
-32
lines changed

8 files changed

+63
-32
lines changed

.buildkite/pipeline.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@ steps:
1919
queue: "juliaecosystem"
2020
os: "macos"
2121
arch: "aarch64"
22-
commands: |
23-
julia --project -e '
24-
# make sure the 1.8-era Manifest works on this Julia version
25-
using Pkg
26-
Pkg.resolve()'
2722
if: build.message !~ /\[skip tests\]/
2823
timeout_in_minutes: 60
2924
matrix:

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.0.0"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
9+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
910
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
1011
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
1112
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2829
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2930

3031
[weakdeps]
32+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
3133
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3234

3335
[extensions]
3436
SpecialFunctionsExt = "SpecialFunctions"
37+
BFloat16sExt = "BFloat16s"
3538

3639
[compat]
3740
Adapt = "4"
3841
Artifacts = "1"
42+
BFloat16s = "0.5"
3943
CEnum = "0.4, 0.5"
4044
CodecBzip2 = "0.8"
4145
ExprTools = "0.1"

ext/BFloat16sExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module BFloat16sExt
2+
3+
using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version
4+
using BFloat16s
5+
6+
# BFloat16 is only supported in MPS starting with macOS 14
7+
if macos_version() >= v"14"
8+
Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16
9+
jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16
10+
end
11+
12+
end # module

lib/mps/matrix.jl

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,52 @@
11
#
22
# matrix enums
33
#
4-
5-
@cenum MPSDataType::UInt32 begin
4+
@cenum MPSDataTypeBits::UInt32 begin
65
MPSDataTypeComplexBit = UInt32(0x01000000)
76
MPSDataTypeFloatBit = UInt32(0x10000000)
87
MPSDataTypeSignedBit = UInt32(0x20000000)
98
MPSDataTypeNormalizedBit = UInt32(0x40000000)
109
MPSDataTypeAlternateEncodingBit = UInt32(0x80000000)
1110
end
11+
12+
@enum MPSDataType::UInt32 begin
13+
MPSDataTypeInvalid = UInt32(0)
14+
15+
MPSDataTypeUInt8 = UInt32(8)
16+
MPSDataTypeUInt16 = UInt32(16)
17+
MPSDataTypeUInt32 = UInt32(32)
18+
MPSDataTypeUInt64 = UInt32(64)
19+
20+
MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
21+
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
22+
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
23+
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)
24+
25+
MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
26+
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)
27+
28+
MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16)
29+
MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
30+
31+
MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
32+
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)
33+
34+
MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
35+
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
36+
end
1237
## bitwise operations lose type information, so allow conversions
1338
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
1439

40+
# Conversions for MPSDataTypes with Julia equivalents
41+
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
42+
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
43+
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
44+
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
45+
end
46+
47+
Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]
48+
49+
1550
#
1651
# matrix descriptor
1752
#
@@ -29,31 +64,11 @@ export MPSMatrixDescriptor
2964
@autoproperty matrixBytes::NSUInteger
3065
end
3166

32-
33-
# Mapping from Julia types to the Performance Shader bitfields
34-
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
35-
UInt8 => UInt32(8),
36-
UInt16 => UInt32(16),
37-
UInt32 => UInt32(32),
38-
UInt64 => UInt32(64),
39-
40-
Int8 => MPSDataTypeSignedBit | UInt32(8),
41-
Int16 => MPSDataTypeSignedBit | UInt32(16),
42-
Int32 => MPSDataTypeSignedBit | UInt32(32),
43-
Int64 => MPSDataTypeSignedBit | UInt32(64),
44-
45-
Float16 => MPSDataTypeFloatBit | UInt32(16),
46-
Float32 => MPSDataTypeFloatBit | UInt32(32),
47-
48-
ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16),
49-
ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
50-
)
51-
5267
function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
5368
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
5469
columns:columns::NSUInteger
5570
rowBytes:rowBytes::NSUInteger
56-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
71+
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
5772
obj = MPSMatrixDescriptor(desc)
5873
# XXX: who releases this object?
5974
return obj
@@ -65,7 +80,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat
6580
matrices:matrices::NSUInteger
6681
rowBytes:rowBytes::NSUInteger
6782
matrixBytes:matrixBytes::NSUInteger
68-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
83+
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
6984
obj = MPSMatrixDescriptor(desc)
7085
# XXX: who releases this object?
7186
return obj

lib/mps/vector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212

1313
function MPSVectorDescriptor(length, dataType)
1414
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
15-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
15+
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
1616
obj = MPSVectorDescriptor(desc)
1717
# XXX: who releases this object?
1818
return obj
@@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType)
2222
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
2323
vectors:vectors::NSUInteger
2424
vectorBytes:vectorBytes::NSUInteger
25-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
25+
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
2626
obj = MPSVectorDescriptor(desc)
2727
# XXX: who releases this object?
2828
return obj

src/Metal.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,8 @@ include("MetalKernels.jl")
6767
import .MetalKernels: MetalBackend
6868
export MetalBackend
6969

70+
@static if !isdefined(Base, :get_extension)
71+
include("../ext/BFloat16sExt.jl")
72+
end
73+
7074
end # module

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
34
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
45
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
56
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ try
245245

246246
# catch timeouts
247247
pid = remotecall_fetch(getpid, wrkr)
248-
timer = Timer(360) do _
248+
timer = Timer(480) do _
249249
@warn "Test timed out: $test"
250250
t1 = rmprocs(wrkr, waitfor=0)
251251

0 commit comments

Comments
 (0)