Skip to content

Commit d8370f7

Browse files
committed
Move linalg wrappers out of MPS lib
1 parent e3b6210 commit d8370f7

File tree

6 files changed

+105
-85
lines changed

6 files changed

+105
-85
lines changed

lib/mps/MPS.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ const MtlFloat = Union{Float32, Float16}
2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
2525

26+
# Valid combination of input (A and B matrices) and output (C) types
27+
const MPS_VALID_MATMUL_TYPES =
28+
[(Int8, Float16),
29+
(Int8, Float32),
30+
(Int16, Float32),
31+
(Float16, Float16),
32+
(Float16, Float32),
33+
(Float32, Float32)]
34+
35+
const MPS_VALID_MATVECMUL_TYPES =
36+
[(Float16, Float16),
37+
(Float16, Float32),
38+
(Float32, Float32)]
39+
2640
is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)
2741

2842
# Load in generated enums and structs
@@ -43,6 +57,5 @@ include("copy.jl")
4357

4458
# integrations
4559
include("random.jl")
46-
include("linalg.jl")
4760

4861
end

lib/mps/command_buf.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# @objcwrapper MPSCommandBuffer <: MTLCommandBuffer
88

9+
export MPSCommandBuffer
10+
911
function MPSCommandBuffer(commandBuffer::MTLCommandBuffer)
1012
handle = @objc [MPSCommandBuffer commandBufferWithCommandBuffer:commandBuffer::id{MTLCommandBuffer}]::id{MPSCommandBuffer}
1113
MPSCommandBuffer(handle)
@@ -32,6 +34,8 @@ function MTL.commit!(f::Base.Callable, cmdbuf::MPSCommandBuffer)
3234
return ret
3335
end
3436

37+
export commitAndContinue!
38+
3539
commitAndContinue!(cmdbuf::MPSCommandBuffer) =
3640
@objc [cmdbuf::id{MPSCommandBuffer} commitAndContinue]::Nothing
3741

src/Metal.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ include("compiler/reflection.jl")
5353
include("../lib/mps/MPS.jl")
5454
export MPS
5555

56+
# LinearAlgebra
57+
include("linalg.jl")
58+
5659
# array implementation
5760
include("utilities.jl")
5861
include("broadcast.jl")

lib/mps/linalg.jl renamed to src/linalg.jl

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
using LinearAlgebra
22
using LinearAlgebra: MulAddMul, wrap
3-
4-
# Valid combination of input (A and B matrices) and output (C) types
5-
const MPS_VALID_MATMUL_TYPES =
6-
[(Int8, Float16),
7-
(Int8, Float32),
8-
(Int16, Float32),
9-
(Float16, Float16),
10-
(Float32, Float32)]
3+
using .MPS
4+
using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
115

126
LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
137
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
@@ -39,19 +33,14 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
3933
typC = eltype(C)
4034

4135
# If possible, dispatch to performance shaders
42-
if is_supported(device()) &&
43-
typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
36+
if MPS.is_supported(device()) &&
37+
typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
4438
matmul!(C, A, B, alpha, beta, transA, transB)
4539
else
4640
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
4741
end
4842
end
4943

50-
const MPS_VALID_MATVECMUL_TYPES =
51-
[(Float16, Float16),
52-
(Float16, Float32),
53-
(Float32, Float32)]
54-
5544
LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B::MtlVector, _add::MulAddMul) =
5645
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
5746
@autoreleasepool function LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar,
@@ -82,24 +71,24 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
8271
typC = eltype(C)
8372

8473
# If possible, dispatch to performance shaders
85-
if is_supported(device()) &&
86-
typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
74+
if MPS.is_supported(device()) &&
75+
typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
8776
matvecmul!(C, A, B, alpha, beta, transA)
8877
else
8978
GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
9079
end
9180
end
9281

9382
@inline checkpositivedefinite(status) =
94-
status == MPSMatrixDecompositionStatusNonPositiveDefinite || throw(PosDefException(status))
83+
status == MPS.MPSMatrixDecompositionStatusNonPositiveDefinite || throw(PosDefException(status))
9584
@inline checknonsingular(status) =
96-
status != MPSMatrixDecompositionStatusSingular || throw(SingularException(status))
85+
status != MPS.MPSMatrixDecompositionStatusSingular || throw(SingularException(status))
9786

9887
# GPU-compatible accessors of the LU decomposition properties
99-
function Base.getproperty(F::LU{T,<:MtlMatrix}, d::Symbol) where T
88+
function Base.getproperty(F::LU{T, <:MtlMatrix}, d::Symbol) where {T}
10089
m, n = size(F)
10190
if d === :L
102-
L = tril!(getfield(F, :factors)[1:m, 1:min(m,n)])
91+
L = tril!(getfield(F, :factors)[1:m, 1:min(m, n)])
10392
L[1:m+1:end] .= one(T)
10493
return L
10594
else
@@ -111,16 +100,16 @@ end
111100
# TODO: figure out a GPU-compatible way to get the permutation matrix
112101
LinearAlgebra.ipiv2perm(v::MtlVector, maxi::Integer) =
113102
LinearAlgebra.ipiv2perm(Array(v), maxi)
114-
LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) =
103+
LinearAlgebra.ipiv2perm(v::MtlVector{<:Any, MTL.CPUStorage}, maxi::Integer) =
115104
LinearAlgebra.ipiv2perm(unsafe_wrap(Array, v), maxi)
116105

117106
@autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T};
118-
check::Bool=true) where {T<:MtlFloat}
119-
M,N = size(A)
107+
check::Bool = true) where {T <: MtlFloat}
108+
M, N = size(A)
120109
dev = device()
121110
queue = global_queue(dev)
122111

123-
At = MtlMatrix{T,PrivateStorage}(undef, (N, M))
112+
At = MtlMatrix{T, PrivateStorage}(undef, (N, M))
124113
mps_a = MPSMatrix(A)
125114
mps_at = MPSMatrix(At)
126115

@@ -131,7 +120,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) =
131120
end
132121

133122
P = similar(A, UInt32, 1, min(N, M))
134-
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)
123+
status = MtlArray{MPS.MPSMatrixDecompositionStatus, 0, SharedStorage}(undef)
135124

136125
commitAndContinue!(cmdbuf) do cbuf
137126
mps_p = MPSMatrix(P)
@@ -172,13 +161,13 @@ end
172161

173162
# TODO: dispatch on pivot strategy
174163
@autoreleasepool function LinearAlgebra.lu!(A::MtlMatrix{T};
175-
check::Bool=true,
176-
allowsingular::Bool=false) where {T<:MtlFloat}
177-
M,N = size(A)
164+
check::Bool = true,
165+
allowsingular::Bool = false) where {T <: MtlFloat}
166+
M, N = size(A)
178167
dev = device()
179168
queue = global_queue(dev)
180169

181-
At = MtlMatrix{T,PrivateStorage}(undef, (N, M))
170+
At = MtlMatrix{T, PrivateStorage}(undef, (N, M))
182171
mps_a = MPSMatrix(A)
183172
mps_at = MPSMatrix(At)
184173

@@ -189,7 +178,7 @@ end
189178
end
190179

191180
P = similar(A, UInt32, 1, min(N, M))
192-
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)
181+
status = MtlArray{MPS.MPSMatrixDecompositionStatus, 0, SharedStorage}(undef)
193182

194183
commitAndContinue!(cmdbuf) do cbuf
195184
mps_p = MPSMatrix(P)
@@ -215,9 +204,9 @@ end
215204

216205
@autoreleasepool function LinearAlgebra.transpose!(B::MtlMatrix{T},
217206
A::MtlMatrix{T}) where {T}
218-
axes(B,2) == axes(A,1) && axes(B,1) == axes(A,2) || throw(DimensionMismatch("transpose"))
207+
axes(B, 2) == axes(A, 1) && axes(B, 1) == axes(A, 2) || throw(DimensionMismatch("transpose"))
219208

220-
M,N = size(A)
209+
M, N = size(A)
221210
dev = device()
222211
queue = global_queue(dev)
223212
cmdbuf = MTLCommandBuffer(queue)

test/linalg.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using LinearAlgebra
2+
3+
if MPS.is_supported(device())
4+
5+
6+
@testset "test matrix vector multiplication of views" begin
7+
N = 20
8+
9+
a = rand(Float32, N, N)
10+
b = rand(Float32, N)
11+
c = a * b
12+
13+
mtl_a = mtl(a)
14+
mtl_b = mtl(b)
15+
mtl_c = mtl_a * mtl_b
16+
17+
@test Array(mtl_c) ≈ c
18+
19+
view_a = @view a[:, 10:end]
20+
view_b = @view b[10:end]
21+
22+
mtl_view_a = @view mtl_a[:, 10:end]
23+
mtl_view_b = @view mtl_b[10:end]
24+
25+
mtl_view_c = mtl_view_a * mtl_view_b
26+
view_c = view_a * view_b
27+
28+
@test Array(mtl_view_c) == view_c
29+
end
30+
31+
using Metal: storagemode
32+
@testset "decompositions" begin
33+
A = MtlMatrix(rand(Float32, 1024, 1024))
34+
lua = lu(A)
35+
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
36+
37+
A = MtlMatrix(rand(Float32, 1024, 512))
38+
lua = lu(A)
39+
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
40+
41+
A = MtlMatrix(rand(Float32, 512, 1024))
42+
lua = lu(A)
43+
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
44+
45+
a = rand(Float32, 1024, 1024)
46+
A = MtlMatrix(a)
47+
B = MtlMatrix(a)
48+
lua = lu!(A)
49+
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * B
50+
51+
A = MtlMatrix{Float32}([1 2; 0 0])
52+
@test_throws SingularException lu(A)
53+
54+
altStorage = Metal.DefaultStorageMode != Metal.PrivateStorage ? Metal.PrivateStorage : Metal.SharedStorage
55+
A = MtlMatrix{Float32, altStorage}(rand(Float32, 1024, 1024))
56+
lua = lu(A)
57+
@test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A)
58+
lua = lu!(A)
59+
@test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A)
60+
end
61+
62+
end

test/mps/linalg.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -69,26 +69,6 @@ end
6969
end
7070
end
7171

72-
@testset "test matrix vector multiplication of views" begin
73-
N = 20
74-
a = rand(Float32, N,N)
75-
b = rand(Float32, N)
76-
77-
mtl_a = mtl(a)
78-
mtl_b = mtl(b)
79-
80-
view_a = @view a[:,10:end]
81-
view_b = @view b[10:end]
82-
83-
mtl_view_a = @view mtl_a[:,10:end]
84-
mtl_view_b = @view mtl_b[10:end]
85-
86-
mtl_c = mtl_view_a * mtl_view_b
87-
c = view_a * view_b
88-
89-
@test Array(mtl_c) == c
90-
end
91-
9272
@testset "mixed-precision matrix vector multiplication" begin
9373
N = 10
9474
rows = N
@@ -180,37 +160,6 @@ end
180160
end
181161
end
182162

183-
using Metal: storagemode
184-
@testset "decompositions" begin
185-
A = MtlMatrix(rand(Float32, 1024, 1024))
186-
lua = lu(A)
187-
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
188-
189-
A = MtlMatrix(rand(Float32, 1024, 512))
190-
lua = lu(A)
191-
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
192-
193-
A = MtlMatrix(rand(Float32, 512, 1024))
194-
lua = lu(A)
195-
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * A
196-
197-
a = rand(Float32, 1024, 1024)
198-
A = MtlMatrix(a)
199-
B = MtlMatrix(a)
200-
lua = lu!(A)
201-
@test lua.L * lua.U ≈ MtlMatrix(lua.P) * B
202-
203-
A = MtlMatrix{Float32}([1 2; 0 0])
204-
@test_throws SingularException lu(A)
205-
206-
altStorage = Metal.DefaultStorageMode != Metal.PrivateStorage ? Metal.PrivateStorage : Metal.SharedStorage
207-
A = MtlMatrix{Float32,altStorage}(rand(Float32, 1024, 1024))
208-
lua = lu(A)
209-
@test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A)
210-
lua = lu!(A)
211-
@test storagemode(lua.factors) == storagemode(lua.ipiv) == storagemode(A)
212-
end
213-
214163
using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax
215164
@testset "MPSMatrixSoftMax" begin
216165
cols = rand(Int)

0 commit comments

Comments
 (0)