Skip to content

Commit c2207d2

Browse files
Fix copy tests (#493)
1 parent d14426c commit c2207d2

File tree

2 files changed

+37
-29
lines changed

2 files changed

+37
-29
lines changed

lib/mps/matrix.jl

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,54 @@
11
## enums
22

33
@cenum MPSDataTypeBits::UInt32 begin
4-
MPSDataTypeComplexBit = UInt32(0x01000000)
54
MPSDataTypeFloatBit = UInt32(0x10000000)
5+
MPSDataTypeComplexBit = UInt32(0x01000000)
66
MPSDataTypeSignedBit = UInt32(0x20000000)
7-
MPSDataTypeNormalizedBit = UInt32(0x40000000)
7+
MPSDataTypeIntBit = UInt32(0x20000000)
88
MPSDataTypeAlternateEncodingBit = UInt32(0x80000000)
9+
MPSDataTypeNormalizedBit = UInt32(0x40000000)
910
end
1011

11-
@enum MPSDataType::UInt32 begin
12-
MPSDataTypeInvalid = UInt32(0)
12+
@cenum MPSDataType::UInt32 begin
13+
MPSDataTypeInvalid = UInt32(0)
1314

14-
MPSDataTypeUInt8 = UInt32(8)
15-
MPSDataTypeUInt16 = UInt32(16)
16-
MPSDataTypeUInt32 = UInt32(32)
17-
MPSDataTypeUInt64 = UInt32(64)
15+
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)
16+
MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
1817

19-
MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
20-
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
21-
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
22-
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)
18+
MPSDataTypeComplexFloat32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(64)
19+
MPSDataTypeComplexFloat16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
2320

24-
MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
25-
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)
21+
MPSDataTypeInt4 = MPSDataTypeSignedBit | UInt32(4)
22+
MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
23+
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
24+
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
25+
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)
2626

27-
MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16)
28-
MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
27+
MPSDataTypeUInt4 = UInt32(4)
28+
MPSDataTypeUInt8 = UInt32(8)
29+
MPSDataTypeUInt16 = UInt32(16)
30+
MPSDataTypeUInt32 = UInt32(32)
31+
MPSDataTypeUInt64 = UInt32(64)
2932

30-
MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
31-
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)
33+
MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
34+
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
3235

33-
MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
34-
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
36+
MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
37+
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)
3538
end
3639
## bitwise operations lose type information, so allow conversions
3740
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
3841

3942
# Conversions for MPSDataTypes with Julia equivalents
4043
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
41-
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
42-
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
43-
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
44+
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,(ComplexF16,:MPSDataTypeComplexFloat16),(ComplexF32,:MPSDataTypeComplexFloat32),Bool]
45+
jltype, mpstype = if type isa Type
46+
type, Symbol(:MPSDataType, type)
47+
else
48+
type
49+
end
50+
@eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype)
51+
@eval jl_mps_to_typ[$(mpstype)] = $jltype
4452
end
4553
Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t])
4654

test/mps/copy.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# XXX: Why 64-bit Integers broken? Same behaviour with Swift
2-
const IGNORE_UNION = Union{Complex, Int64, UInt64}
2+
copy_is_broken(T) = sizeof(T) >= 8
33

44
function copytest(src, srctrans, dsttrans)
55
dev = device()
66
queue = global_queue(dev)
77
dst = if srctrans == dsttrans
88
similar(src)
99
else
10-
similar(src')
10+
similar(transpose(src))
1111
end
1212

1313
if dsttrans
@@ -34,15 +34,15 @@ end
3434
srcMat = MtlArray(rand(T, dim))
3535

3636
dstMat = copytest(srcMat, false, false)
37-
@test dstMat == srcMat broken=(T <: IGNORE_UNION)
37+
@test dstMat == srcMat broken=copy_is_broken(T)
3838

3939
dstMat = copytest(srcMat, true, false)
40-
@test dstMat == srcMat' broken=(T <: IGNORE_UNION)
40+
@test dstMat == transpose(srcMat) broken=copy_is_broken(T)
4141

4242
dstMat = copytest(srcMat, false, true)
43-
@test dstMat == srcMat' broken=(T <: IGNORE_UNION)
43+
@test dstMat == transpose(srcMat) broken=copy_is_broken(T)
4444

4545
dstMat = copytest(srcMat, true, true)
46-
@test dstMat == srcMat broken=(T <: IGNORE_UNION)
46+
@test dstMat == srcMat broken=copy_is_broken(T)
4747
end
4848
end

0 commit comments

Comments
 (0)