11#
22# matrix enums
33#
4-
5- @cenum MPSDataType:: UInt32 begin
4+ @cenum MPSDataTypeBits:: UInt32 begin
65 MPSDataTypeComplexBit = UInt32 (0x01000000 )
76 MPSDataTypeFloatBit = UInt32 (0x10000000 )
87 MPSDataTypeSignedBit = UInt32 (0x20000000 )
98 MPSDataTypeNormalizedBit = UInt32 (0x40000000 )
109 MPSDataTypeAlternateEncodingBit = UInt32 (0x80000000 )
1110end
11+
12+ @enum MPSDataType:: UInt32 begin
13+ MPSDataTypeInvalid = UInt32 (0 )
14+
15+ MPSDataTypeUInt8 = UInt32 (8 )
16+ MPSDataTypeUInt16 = UInt32 (16 )
17+ MPSDataTypeUInt32 = UInt32 (32 )
18+ MPSDataTypeUInt64 = UInt32 (64 )
19+
20+ MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32 (8 )
21+ MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32 (16 )
22+ MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32 (32 )
23+ MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32 (64 )
24+
25+ MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32 (16 )
26+ MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32 (32 )
27+
28+ MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32 (16 )
29+ MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32 (32 )
30+
31+ MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32 (1 )
32+ MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32 (8 )
33+
34+ MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32 (8 )
35+ MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32 (16 )
36+ end
1237# # bitwise operations lose type information, so allow conversions
1338Base. convert (:: Type{MPSDataType} , x:: Integer ) = MPSDataType (x)
1439
40+ # Conversions for MPSDataTypes with Julia equivalents
41+ const jl_mps_to_typ = Dict {MPSDataType, DataType} ()
42+ for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
43+ @eval Base. convert (:: Type{MPSDataType} , :: Type{$type} ) = $ (Symbol (:MPSDataType , type))
44+ @eval jl_mps_to_typ[$ (Symbol (:MPSDataType , type))] = $ type
45+ end
46+
47+ Base. convert (:: Type{DataType} , mpstyp:: MPSDataType ) = jl_mps_to_typ[mpstyp]
48+
49+
1550#
1651# matrix descriptor
1752#
@@ -29,31 +64,11 @@ export MPSMatrixDescriptor
2964 @autoproperty matrixBytes:: NSUInteger
3065end
3166
32-
33- # Mapping from Julia types to the Performance Shader bitfields
34- const jl_typ_to_mps = Dict {DataType,MPSDataType} (
35- UInt8 => UInt32 (8 ),
36- UInt16 => UInt32 (16 ),
37- UInt32 => UInt32 (32 ),
38- UInt64 => UInt32 (64 ),
39-
40- Int8 => MPSDataTypeSignedBit | UInt32 (8 ),
41- Int16 => MPSDataTypeSignedBit | UInt32 (16 ),
42- Int32 => MPSDataTypeSignedBit | UInt32 (32 ),
43- Int64 => MPSDataTypeSignedBit | UInt32 (64 ),
44-
45- Float16 => MPSDataTypeFloatBit | UInt32 (16 ),
46- Float32 => MPSDataTypeFloatBit | UInt32 (32 ),
47-
48- ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32 (16 ),
49- ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32 (32 )
50- )
51-
5267function MPSMatrixDescriptor (rows, columns, rowBytes, dataType)
5368 desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows: rows:: NSUInteger
5469 columns: columns:: NSUInteger
5570 rowBytes: rowBytes:: NSUInteger
56- dataType: jl_typ_to_mps[ dataType] :: MPSDataType ]:: id{MPSMatrixDescriptor}
71+ dataType: dataType:: MPSDataType ]:: id{MPSMatrixDescriptor}
5772 obj = MPSMatrixDescriptor (desc)
5873 # XXX : who releases this object?
5974 return obj
@@ -65,7 +80,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat
6580 matrices: matrices:: NSUInteger
6681 rowBytes: rowBytes:: NSUInteger
6782 matrixBytes: matrixBytes:: NSUInteger
68- dataType: jl_typ_to_mps[ dataType] :: MPSDataType ]:: id{MPSMatrixDescriptor}
83+ dataType: dataType:: MPSDataType ]:: id{MPSMatrixDescriptor}
6984 obj = MPSMatrixDescriptor (desc)
7085 # XXX : who releases this object?
7186 return obj
0 commit comments