Skip to content

Commit 558ccfa

Browse files
committed
Improvements to MPSDataType
1 parent bc2131e commit 558ccfa

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

lib/mps/matrix.jl

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,56 @@
11
#
22
# matrix enums
33
#
4-
5-
@cenum MPSDataType::UInt32 begin
4+
@cenum MPSDataTypeBits::UInt32 begin
65
MPSDataTypeComplexBit = UInt32(0x01000000)
76
MPSDataTypeFloatBit = UInt32(0x10000000)
87
MPSDataTypeSignedBit = UInt32(0x20000000)
98
MPSDataTypeNormalizedBit = UInt32(0x40000000)
109
MPSDataTypeAlternateEncodingBit = UInt32(0x80000000)
1110
end
11+
12+
@enum MPSDataType::UInt32 begin
13+
MPSDataTypeInvalid = UInt32(0)
14+
15+
MPSDataTypeUInt8 = UInt32(8)
16+
MPSDataTypeUInt16 = UInt32(16)
17+
MPSDataTypeUInt32 = UInt32(32)
18+
MPSDataTypeUInt64 = UInt32(64)
19+
20+
MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
21+
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
22+
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
23+
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)
24+
25+
MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
26+
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)
27+
28+
MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16)
29+
MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
30+
31+
MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
32+
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)
33+
34+
MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
35+
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
36+
end
1237
## bitwise operations lose type information, so allow conversions
1338
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
1439

40+
# Conversions for MPSDataTypes with Julia equivalents
41+
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
42+
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
43+
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
44+
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
45+
end
46+
# BFloat is only supported in MPS starting in MacOS 14
47+
if macos_version() >= v"14" && isdefined(Core, :BFloat16)
48+
Base.convert(::Type{MPSDataType}, ::Type{Core.BFloat16}) = MPSDataTypeBFloat16
49+
jl_mps_to_typ[MPSDataTypeBFloat16] = Core.BFloat16
50+
end
51+
Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]
52+
53+
1554
#
1655
# matrix descriptor
1756
#
@@ -29,31 +68,11 @@ export MPSMatrixDescriptor
2968
@autoproperty matrixBytes::NSUInteger
3069
end
3170

32-
33-
# Mapping from Julia types to the Performance Shader bitfields
34-
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
35-
UInt8 => UInt32(8),
36-
UInt16 => UInt32(16),
37-
UInt32 => UInt32(32),
38-
UInt64 => UInt32(64),
39-
40-
Int8 => MPSDataTypeSignedBit | UInt32(8),
41-
Int16 => MPSDataTypeSignedBit | UInt32(16),
42-
Int32 => MPSDataTypeSignedBit | UInt32(32),
43-
Int64 => MPSDataTypeSignedBit | UInt32(64),
44-
45-
Float16 => MPSDataTypeFloatBit | UInt32(16),
46-
Float32 => MPSDataTypeFloatBit | UInt32(32),
47-
48-
ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16),
49-
ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
50-
)
51-
5271
function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
5372
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
5473
columns:columns::NSUInteger
5574
rowBytes:rowBytes::NSUInteger
56-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
75+
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
5776
obj = MPSMatrixDescriptor(desc)
5877
# XXX: who releases this object?
5978
return obj
@@ -65,7 +84,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat
6584
matrices:matrices::NSUInteger
6685
rowBytes:rowBytes::NSUInteger
6786
matrixBytes:matrixBytes::NSUInteger
68-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
87+
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
6988
obj = MPSMatrixDescriptor(desc)
7089
# XXX: who releases this object?
7190
return obj

lib/mps/vector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212

1313
function MPSVectorDescriptor(length, dataType)
1414
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
15-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
15+
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
1616
obj = MPSVectorDescriptor(desc)
1717
# XXX: who releases this object?
1818
return obj
@@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType)
2222
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
2323
vectors:vectors::NSUInteger
2424
vectorBytes:vectorBytes::NSUInteger
25-
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
25+
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
2626
obj = MPSVectorDescriptor(desc)
2727
# XXX: who releases this object?
2828
return obj

0 commit comments

Comments
 (0)