11using LinearAlgebra
22using LinearAlgebra: MulAddMul, wrap
3-
4- # Valid combination of input (A and B matrices) and output (C) types
5- const MPS_VALID_MATMUL_TYPES =
6- [(Int8, Float16),
7- (Int8, Float32),
8- (Int16, Float32),
9- (Float16, Float16),
10- (Float32, Float32)]
3+ using . MPS
4+ using . MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
115
# `MulAddMul` entry point: pull the `alpha` and `beta` fields out of `_add` and
# forward to the `generic_matmatmul!` method that takes them as separate arguments.
function LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix,
                                          _add::MulAddMul)
    alpha, beta = _add.alpha, _add.beta
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
end
@@ -39,19 +33,14 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
3933 typC = eltype (C)
4034
4135 # If possible, dispatch to performance shaders
42- if is_supported (device ()) &&
43- typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
36+ if MPS . is_supported (device ()) &&
37+ typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
4438 matmul! (C, A, B, alpha, beta, transA, transB)
4539 else
4640 GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
4741 end
4842end
4943
50- const MPS_VALID_MATVECMUL_TYPES =
51- [(Float16, Float16),
52- (Float16, Float32),
53- (Float32, Float32)]
54-
# `MulAddMul` entry point: pull the `alpha` and `beta` fields out of `_add` and
# forward to the `generic_matvecmul!` method that takes them as separate arguments.
function LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix,
                                          B::MtlVector, _add::MulAddMul)
    alpha, beta = _add.alpha, _add.beta
    return LinearAlgebra.generic_matvecmul!(C, tA, A, B, alpha, beta)
end
5746@autoreleasepool function LinearAlgebra. generic_matvecmul! (C:: MtlVector , tA:: AbstractChar ,
@@ -82,24 +71,24 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
8271 typC = eltype (C)
8372
8473 # If possible, dispatch to performance shaders
85- if is_supported (device ()) &&
86- typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
74+ if MPS . is_supported (device ()) &&
75+ typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
8776 matvecmul! (C, A, B, alpha, beta, transA)
8877 else
8978 GPUArrays. generic_matmatmul! (C, wrap (A, tA), B, alpha, beta)
9079 end
9180end
9281
# Throw `PosDefException(status)` when `status` equals
# `MPS.MPSMatrixDecompositionStatusNonPositiveDefinite`; otherwise evaluate to `true`.
#
# The comparison must be `!=` (mirroring `checknonsingular`): with `==` the
# short-circuit `||` would throw for every status *except* the failing one.
@inline checkpositivedefinite(status) =
    status != MPS.MPSMatrixDecompositionStatusNonPositiveDefinite ||
    throw(PosDefException(status))
# Throw `SingularException(status)` when `status` equals
# `MPS.MPSMatrixDecompositionStatusSingular`; otherwise return `true`.
@inline function checknonsingular(status)
    if status == MPS.MPSMatrixDecompositionStatusSingular
        throw(SingularException(status))
    end
    return true
end
9786
9887# GPU-compatible accessors of the LU decomposition properties
99- function Base. getproperty (F:: LU{T,<:MtlMatrix} , d:: Symbol ) where T
88+ function Base. getproperty (F:: LU{T, <:MtlMatrix} , d:: Symbol ) where {T}
10089 m, n = size (F)
10190 if d === :L
102- L = tril! (getfield (F, :factors )[1 : m, 1 : min (m,n)])
91+ L = tril! (getfield (F, :factors )[1 : m, 1 : min (m, n)])
10392 L[1 : m+ 1 : end ] .= one (T)
10493 return L
10594 else
@@ -111,16 +100,16 @@ end
111100# TODO : figure out a GPU-compatible way to get the permutation matrix
# Fallback for any `MtlVector`: convert the pivot vector to a plain `Array` and
# reuse the `LinearAlgebra.ipiv2perm` method for that array.
function LinearAlgebra.ipiv2perm(v::MtlVector, maxi::Integer)
    pivots = Array(v)
    return LinearAlgebra.ipiv2perm(pivots, maxi)
end
# Specialization for `MTL.CPUStorage`-backed vectors: view the pivot vector as an
# `Array` through `unsafe_wrap` (rather than `Array(v)`) before delegating.
# NOTE(review): presumably this avoids a copy for CPU-resident storage — confirm.
function LinearAlgebra.ipiv2perm(v::MtlVector{<:Any, MTL.CPUStorage}, maxi::Integer)
    pivots = unsafe_wrap(Array, v)
    return LinearAlgebra.ipiv2perm(pivots, maxi)
end
116105
117106@autoreleasepool function LinearAlgebra. lu (A:: MtlMatrix{T} ;
118- check:: Bool = true ) where {T<: MtlFloat }
119- M,N = size (A)
107+ check:: Bool = true ) where {T <: MtlFloat }
108+ M, N = size (A)
120109 dev = device ()
121110 queue = global_queue (dev)
122111
123- At = MtlMatrix {T,PrivateStorage} (undef, (N, M))
112+ At = MtlMatrix {T, PrivateStorage} (undef, (N, M))
124113 mps_a = MPSMatrix (A)
125114 mps_at = MPSMatrix (At)
126115
@@ -131,7 +120,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) =
131120 end
132121
133122 P = similar (A, UInt32, 1 , min (N, M))
134- status = MtlArray {MPSMatrixDecompositionStatus,0, SharedStorage} (undef)
123+ status = MtlArray {MPS. MPSMatrixDecompositionStatus, 0, SharedStorage} (undef)
135124
136125 commitAndContinue! (cmdbuf) do cbuf
137126 mps_p = MPSMatrix (P)
@@ -172,13 +161,13 @@ end
172161
173162# TODO : dispatch on pivot strategy
174163@autoreleasepool function LinearAlgebra. lu! (A:: MtlMatrix{T} ;
175- check:: Bool = true ,
176- allowsingular:: Bool = false ) where {T<: MtlFloat }
177- M,N = size (A)
164+ check:: Bool = true ,
165+ allowsingular:: Bool = false ) where {T <: MtlFloat }
166+ M, N = size (A)
178167 dev = device ()
179168 queue = global_queue (dev)
180169
181- At = MtlMatrix {T,PrivateStorage} (undef, (N, M))
170+ At = MtlMatrix {T, PrivateStorage} (undef, (N, M))
182171 mps_a = MPSMatrix (A)
183172 mps_at = MPSMatrix (At)
184173
189178 end
190179
191180 P = similar (A, UInt32, 1 , min (N, M))
192- status = MtlArray {MPSMatrixDecompositionStatus,0, SharedStorage} (undef)
181+ status = MtlArray {MPS. MPSMatrixDecompositionStatus, 0, SharedStorage} (undef)
193182
194183 commitAndContinue! (cmdbuf) do cbuf
195184 mps_p = MPSMatrix (P)
215204
216205@autoreleasepool function LinearAlgebra. transpose! (B:: MtlMatrix{T} ,
217206 A:: MtlMatrix{T} ) where {T}
218- axes (B,2 ) == axes (A,1 ) && axes (B,1 ) == axes (A,2 ) || throw (DimensionMismatch (" transpose" ))
207+ axes (B, 2 ) == axes (A, 1 ) && axes (B, 1 ) == axes (A, 2 ) || throw (DimensionMismatch (" transpose" ))
219208
220- M,N = size (A)
209+ M, N = size (A)
221210 dev = device ()
222211 queue = global_queue (dev)
223212 cmdbuf = MTLCommandBuffer (queue)
0 commit comments