Skip to content

Commit 8d4dd5d

Browse files
Code coverage and misc fixes (#569)
* NDArray construction bug fix * `size` for MPSMatrix * Add and fix MPSMatrix tests * Don't track coverage for examples * Exclude device-only code from coverage * Exclude profiler tests from coverage * format * Struct coverage * MTLDevice tests * Move MPSDataType out of matrix file and test * Storage type tests * Format * Fix * Update lib/mps/matrix.jl Co-authored-by: Tim Besard <[email protected]> * Remove unused code and tests --------- Co-authored-by: Tim Besard <[email protected]>
1 parent b7606f0 commit 8d4dd5d

File tree

13 files changed

+230
-124
lines changed

13 files changed

+230
-124
lines changed

.buildkite/pipeline.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ steps:
1414
dirs:
1515
- src
1616
- lib
17-
- examples
1817
agents:
1918
queue: "juliaecosystem"
2019
os: "macos"
@@ -84,7 +83,6 @@ steps:
8483
dirs:
8584
- src
8685
- lib
87-
- examples
8886
env:
8987
MTL_DEBUG_LAYER: '1'
9088
MTL_SHADER_VALIDATION: '1'
@@ -113,7 +111,6 @@ steps:
113111
dirs:
114112
- src
115113
- lib
116-
- examples
117114
env:
118115
JULIA_LLVM_ARGS: '--opaque-pointers'
119116
agents:

lib/mps/MPS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice}
4343
include("libmps.jl")
4444

4545
include("size.jl")
46+
include("datatype.jl")
4647

4748
# high-level wrappers
4849
include("command_buf.jl")

lib/mps/datatype.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
## Some extra definitions for MPSDataType defined in libmps.jl
2+
3+
# Conversions for MPSDataTypes with Julia equivalents
4+
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
5+
for type in [
6+
:Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64,
7+
:Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16),
8+
(:ComplexF32, :MPSDataTypeComplexFloat32),
9+
]
10+
jltype, mpstype = if type isa Symbol
11+
type, Symbol(:MPSDataType, type)
12+
else
13+
type
14+
end
15+
@eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype)
16+
@eval jl_mps_to_typ[$(mpstype)] = $jltype
17+
end

lib/mps/matrix.jl

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,3 @@
1-
## Some extra definitions for MPSDataType defined in libmps.jl
2-
3-
## bitwise operations lose type information, so allow conversions
4-
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
5-
6-
# Conversions for MPSDataTypes with Julia equivalents
7-
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
8-
for type in [
9-
:Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64,
10-
:Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16),
11-
(:ComplexF32, :MPSDataTypeComplexFloat32),
12-
]
13-
jltype, mpstype = if type isa Symbol
14-
type, Symbol(:MPSDataType, type)
15-
else
16-
type
17-
end
18-
@eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype)
19-
@eval jl_mps_to_typ[$(mpstype)] = $jltype
20-
end
21-
Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t])
22-
23-
Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]
24-
25-
261
## descriptor
272

283
export MPSMatrixDescriptor
@@ -119,6 +94,13 @@ function MPSMatrix(arr::MtlArray{T,3}) where T
11994
return MPSMatrix(arr, desc, offset)
12095
end
12196

97+
function Base.size(mat::MPS.MPSMatrix)
98+
if mat.matrices > 1
99+
return Int.((mat.matrices, mat.rows, mat.columns))
100+
else
101+
return Int.((mat.rows, mat.columns))
102+
end
103+
end
122104

123105
## matrix multiplication
124106

@@ -160,7 +142,7 @@ with any `MtlArray` and it should be accelerated using Metal Performance Shaders
160142
"""
161143
function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N},
162144
alpha::Number=true, beta::Number=true,
163-
transpose_a=false, transpose_b=false) where {T1, T2, T3, N}
145+
transpose_a=false, transpose_b=false) where {T1, T2, T3, N}
164146
# NOTE: MPS uses row major, while Julia is col-major. Instead of transposing
165147
# the inputs (by passing !transpose_[ab]) and afterwards transposing
166148
# the output, we use the property that (AB)ᵀ = BᵀAᵀ

lib/mps/ndarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
114114
arrsize = size(arr)
115115
@assert arrsize[1] * sizeof(T) % 16 == 0 "First dimension of input MtlArray must have a byte size divisible by 16"
116116
desc = MPSNDArrayDescriptor(T, arrsize)
117-
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
117+
return MPSNDArray(arr.data[], UInt(arr.offset) * sizeof(T), desc)
118118
end
119119

120120
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)

src/MetalKernels.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ end
131131

132132
## indexing
133133

134+
## COV_EXCL_START
134135
@device_override @inline function KA.__index_Local_Linear(ctx)
135136
return thread_position_in_threadgroup_1d()
136137
end
@@ -191,5 +192,6 @@ end
191192
@device_override @inline function KA.__print(args...)
192193
# TODO
193194
end
195+
## COV_EXCL_STOP
194196

195197
end

src/accumulate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
## COV_EXCL_START
12
function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
23
Rdim, Rpre, Rpost, Rother, neutral, init,
34
::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
@@ -100,6 +101,7 @@ function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates:
100101

101102
return
102103
end
104+
## COV_EXCL_STOP
103105

104106
function scan!(f::Function, output::WrappedMtlArray{T}, input::WrappedMtlArray;
105107
dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T}

src/broadcast.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ end
5959
_broadcast_shapes[Is] += 1
6060
end
6161
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
62+
## COV_EXCL_START
6263
function broadcast_cartesian_static(dest, bc, Is)
6364
i = thread_position_in_grid_1d()
6465
stride = threads_per_grid_1d()
@@ -69,6 +70,7 @@ end
6970
end
7071
return
7172
end
73+
## COV_EXCL_STOP
7274

7375
Is = StaticCartesianIndices(Is)
7476
kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
@@ -82,6 +84,7 @@ end
8284
# try to use the most appropriate hardware index to avoid integer division
8385
if ndims(dest) == 1 ||
8486
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
87+
## COV_EXCL_START
8588
function broadcast_linear(dest, bc)
8689
i = thread_position_in_grid_1d()
8790
stride = threads_per_grid_1d()
@@ -91,12 +94,14 @@ end
9194
end
9295
return
9396
end
97+
## COV_EXCL_STOP
9498

9599
kernel = @metal launch=false broadcast_linear(dest, bc)
96100
elements = cld(length(dest), 4)
97101
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
98102
groups = cld(elements, threads)
99103
elseif ndims(dest) == 2
104+
## COV_EXCL_START
100105
function broadcast_2d(dest, bc)
101106
is = Tuple(thread_position_in_grid_2d())
102107
stride = threads_per_grid_2d()
@@ -107,13 +112,15 @@ end
107112
end
108113
return
109114
end
115+
## COV_EXCL_STOP
110116

111117
kernel = @metal launch=false broadcast_2d(dest, bc)
112118
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
113119
h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
114120
threads = (w, h)
115121
groups = cld.(size(dest), threads)
116122
elseif ndims(dest) == 3
123+
## COV_EXCL_START
117124
function broadcast_3d(dest, bc)
118125
is = Tuple(thread_position_in_grid_3d())
119126
stride = threads_per_grid_3d()
@@ -126,6 +133,7 @@ end
126133
end
127134
return
128135
end
136+
## COV_EXCL_STOP
129137

130138
kernel = @metal launch=false broadcast_3d(dest, bc)
131139
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
@@ -135,6 +143,7 @@ end
135143
threads = (w, h, d)
136144
groups = cld.(size(dest), threads)
137145
else
146+
## COV_EXCL_START
138147
function broadcast_cartesian(dest, bc)
139148
i = thread_position_in_grid_1d()
140149
stride = threads_per_grid_1d()
@@ -145,6 +154,7 @@ end
145154
end
146155
return
147156
end
157+
## COV_EXCL_STOP
148158

149159
kernel = @metal launch=false broadcast_cartesian(dest, bc)
150160
elements = cld(length(dest), 4)

src/utilities.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ end
146146

147147

148148
## profile macro
149-
149+
## COV_EXCL_START
150150
function profile_dir()
151151
root = pwd()
152152
i = 1
@@ -239,3 +239,5 @@ macro profile(ex...)
239239
end
240240
end
241241
end
242+
## COV_EXCL_START
243+

test/mps/matrix.jl

Lines changed: 77 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -45,77 +45,83 @@ using .MPS: MPSMatrix
4545
rowBytes = sizeof(T) * cols
4646
mats = 4
4747

48-
desc = MPSMatrixDescriptor(rows, cols, rowBytes, T)
49-
devmat = MPSMatrix(dev, desc)
50-
@test devmat isa MPSMatrix
51-
@test devmat.device == dev
52-
@test devmat.rows == rows
53-
@test devmat.columns == cols
54-
@test devmat.rowBytes == rowBytes
55-
@test devmat.matrices == 1
56-
@test devmat.dataType == DT
57-
@test devmat.matrixBytes == rowBytes * rows
58-
@test devmat.offset == 0
59-
60-
mat = MtlMatrix{T}(undef, rows, cols)
61-
acols, arows = size(mat)
62-
arowBytes = sizeof(T) * acols
63-
abufmat = MPSMatrix(mat)
64-
@test abufmat isa MPSMatrix
65-
@test abufmat.device == dev
66-
@test abufmat.rows == arows
67-
@test abufmat.columns == acols
68-
@test abufmat.rowBytes == arowBytes
69-
@test abufmat.matrices == 1
70-
@test abufmat.dataType == DT
71-
@test abufmat.matrixBytes == arowBytes * arows
72-
@test abufmat.offset == 0
73-
@test abufmat.data == mat.data[]
74-
75-
vmat = @view mat[:, 2:3]
76-
vcols, vrows = size(vmat)
77-
vrowBytes = sizeof(T) * vcols
78-
vbufmat = MPSMatrix(vmat)
79-
@test vbufmat isa MPSMatrix
80-
@test vbufmat.device == dev
81-
@test vbufmat.rows == vrows
82-
@test vbufmat.columns == vcols
83-
@test vbufmat.rowBytes == vrowBytes
84-
@test vbufmat.matrices == 1
85-
@test vbufmat.dataType == DT
86-
@test vbufmat.matrixBytes == vrowBytes * vrows
87-
@test vbufmat.offset == vmat.offset * sizeof(T)
88-
@test vbufmat.data == vmat.data[]
89-
90-
arr = MtlArray{T,3}(undef, rows, cols, mats)
91-
mcols, mrows, mmats = size(arr)
92-
mrowBytes = sizeof(T) * mcols
93-
mpsmat = MPSMatrix(mat)
94-
@test mpsmat isa MPSMatrix
95-
@test mpsmat.device == dev
96-
@test mpsmat.rows == mrows
97-
@test mpsmat.columns == mcols
98-
@test mpsmat.rowBytes == mrowBytes
99-
@test mpsmat.matrices == 1
100-
@test mpsmat.dataType == DT
101-
@test mpsmat.matrixBytes == mrowBytes * mrows
102-
@test mpsmat.offset == 0
103-
@test mpsmat.data == mat.data[]
104-
105-
vec = MtlVector{T}(undef, rows)
106-
veccols, vecrows = length(vec), 1
107-
vecrowBytes = sizeof(T)*veccols
108-
vmpsmat = MPSMatrix(vec)
109-
@test vmpsmat isa MPSMatrix
110-
@test vmpsmat.device == dev
111-
@test vmpsmat.rows == vecrows
112-
@test vmpsmat.columns == veccols
113-
@test vmpsmat.rowBytes == vecrowBytes
114-
@test vmpsmat.matrices == 1
115-
@test vmpsmat.dataType == DT
116-
@test vmpsmat.matrixBytes == vecrowBytes*vecrows
117-
@test vmpsmat.offset == 0
118-
@test vmpsmat.data == vec.data[]
48+
let desc = MPSMatrixDescriptor(rows, cols, rowBytes, T)
49+
devmat = MPSMatrix(dev, desc)
50+
@test devmat isa MPSMatrix
51+
@test devmat.device == dev
52+
@test devmat.rows == rows
53+
@test devmat.columns == cols
54+
@test devmat.rowBytes == rowBytes
55+
@test devmat.matrices == 1
56+
@test devmat.dataType == DT
57+
@test devmat.matrixBytes == rowBytes * rows
58+
@test devmat.offset == 0
59+
@test size(devmat) == (rows, cols)
60+
end
61+
62+
let mat = MtlMatrix{T}(undef, rows, cols)
63+
acols, arows = size(mat)
64+
arowBytes = sizeof(T) * acols
65+
abufmat = MPSMatrix(mat)
66+
@test abufmat isa MPSMatrix
67+
@test abufmat.device == dev
68+
@test abufmat.rows == arows
69+
@test abufmat.columns == acols
70+
@test abufmat.rowBytes == arowBytes
71+
@test abufmat.matrices == 1
72+
@test abufmat.dataType == DT
73+
@test abufmat.matrixBytes == arowBytes * arows
74+
@test abufmat.offset == 0
75+
@test abufmat.data == mat.data[]
76+
77+
vmat = @view mat[:, 2:3]
78+
vcols, vrows = size(vmat)
79+
vrowBytes = sizeof(T) * vcols
80+
vbufmat = MPSMatrix(vmat)
81+
@test vbufmat isa MPSMatrix
82+
@test vbufmat.device == dev
83+
@test vbufmat.rows == vrows
84+
@test vbufmat.columns == vcols
85+
@test vbufmat.rowBytes == vrowBytes
86+
@test vbufmat.matrices == 1
87+
@test vbufmat.dataType == DT
88+
@test vbufmat.matrixBytes == vrowBytes * vrows
89+
@test vbufmat.offset == vmat.offset * sizeof(T)
90+
@test vbufmat.data == vmat.data[]
91+
end
92+
93+
let arr = MtlArray{T, 3}(undef, rows, cols, mats)
94+
mcols, mrows, mmats = size(arr)
95+
mrowBytes = sizeof(T) * mcols
96+
mpsmat = MPSMatrix(arr)
97+
@test mpsmat isa MPSMatrix
98+
@test mpsmat.device == dev
99+
@test mpsmat.rows == mrows
100+
@test mpsmat.columns == mcols
101+
@test mpsmat.rowBytes == mrowBytes
102+
@test mpsmat.matrices == mmats
103+
@test mpsmat.dataType == DT
104+
@test mpsmat.matrixBytes == mrowBytes * mrows
105+
@test mpsmat.offset == 0
106+
@test mpsmat.data == arr.data[]
107+
@test size(mpsmat) == (mmats, mrows, mcols)
108+
end
109+
110+
let vec = MtlVector{T}(undef, rows)
111+
veccols, vecrows = length(vec), 1
112+
vecrowBytes = sizeof(T) * veccols
113+
vmpsmat = MPSMatrix(vec)
114+
@test vmpsmat isa MPSMatrix
115+
@test vmpsmat.device == dev
116+
@test vmpsmat.rows == vecrows
117+
@test vmpsmat.columns == veccols
118+
@test vmpsmat.rowBytes == vecrowBytes
119+
@test vmpsmat.matrices == 1
120+
@test vmpsmat.dataType == DT
121+
@test vmpsmat.matrixBytes == vecrowBytes * vecrows
122+
@test vmpsmat.offset == 0
123+
@test vmpsmat.data == vec.data[]
124+
end
119125
end
120126

121127

0 commit comments

Comments
 (0)