|
1 | 1 | ## enums |
2 | 2 |
|
3 | 3 | @cenum MPSDataTypeBits::UInt32 begin |
4 | | - MPSDataTypeComplexBit = UInt32(0x01000000) |
5 | 4 | MPSDataTypeFloatBit = UInt32(0x10000000) |
| 5 | + MPSDataTypeComplexBit = UInt32(0x01000000) |
6 | 6 | MPSDataTypeSignedBit = UInt32(0x20000000) |
7 | | - MPSDataTypeNormalizedBit = UInt32(0x40000000) |
| 7 | + MPSDataTypeIntBit = UInt32(0x20000000) |
8 | 8 | MPSDataTypeAlternateEncodingBit = UInt32(0x80000000) |
| 9 | + MPSDataTypeNormalizedBit = UInt32(0x40000000) |
9 | 10 | end |
10 | 11 |
|
11 | | -@enum MPSDataType::UInt32 begin |
12 | | - MPSDataTypeInvalid = UInt32(0) |
| 12 | +@cenum MPSDataType::UInt32 begin |
| 13 | + MPSDataTypeInvalid = UInt32(0) |
13 | 14 |
|
14 | | - MPSDataTypeUInt8 = UInt32(8) |
15 | | - MPSDataTypeUInt16 = UInt32(16) |
16 | | - MPSDataTypeUInt32 = UInt32(32) |
17 | | - MPSDataTypeUInt64 = UInt32(64) |
| 15 | + MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32) |
| 16 | + MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16) |
18 | 17 |
|
19 | | - MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8) |
20 | | - MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16) |
21 | | - MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32) |
22 | | - MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64) |
| 18 | + MPSDataTypeComplexFloat32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(64) |
| 19 | + MPSDataTypeComplexFloat16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) |
23 | 20 |
|
24 | | - MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16) |
25 | | - MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32) |
| 21 | + MPSDataTypeInt4 = MPSDataTypeSignedBit | UInt32(4) |
| 22 | + MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8) |
| 23 | + MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16) |
| 24 | + MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32) |
| 25 | + MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64) |
26 | 26 |
|
27 | | - MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16) |
28 | | - MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) |
| 27 | + MPSDataTypeUInt4 = UInt32(4) |
| 28 | + MPSDataTypeUInt8 = UInt32(8) |
| 29 | + MPSDataTypeUInt16 = UInt32(16) |
| 30 | + MPSDataTypeUInt32 = UInt32(32) |
| 31 | + MPSDataTypeUInt64 = UInt32(64) |
29 | 32 |
|
30 | | - MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1) |
31 | | - MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8) |
| 33 | + MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8) |
| 34 | + MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16) |
32 | 35 |
|
33 | | - MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8) |
34 | | - MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16) |
| 36 | + MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1) |
| 37 | + MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8) |
35 | 38 | end |
36 | 39 | ## bitwise operations lose type information, so allow conversions |
37 | 40 | Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) |
38 | 41 |
|
39 | 42 | # Conversions for MPSDataTypes with Julia equivalents |
40 | 43 | const jl_mps_to_typ = Dict{MPSDataType, DataType}() |
41 | | -for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool] |
42 | | - @eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type)) |
43 | | - @eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type |
| 44 | +for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,(ComplexF16,:MPSDataTypeComplexFloat16),(ComplexF32,:MPSDataTypeComplexFloat32),Bool] |
| 45 | + jltype, mpstype = if type isa Type |
| 46 | + type, Symbol(:MPSDataType, type) |
| 47 | + else |
| 48 | + type |
| 49 | + end |
| 50 | + @eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype) |
| 51 | + @eval jl_mps_to_typ[$(mpstype)] = $jltype |
44 | 52 | end |
45 | 53 | Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t]) |
46 | 54 |
|
|
0 commit comments