+
+@static if isdefined(Base, :OncePerProcess) # VERSION >= v"1.12.0-DEV.1421"
+    const default_exec_desc = OncePerProcess{MPSGraphExecutionDescriptor}() do
+        compDesc = MPSGraphCompilationDescriptor()
+        # Use optimization level 0 to avoid operations being moved to the neural engine
+        compDesc.optimizationLevel = MPSGraphOptimizationLevel0
+
+        execDesc = MPSGraphExecutionDescriptor()
+        execDesc.compilationDescriptor = compDesc
+        execDesc
+    end
+else
+    const _default_exec_desc::Ref{MPSGraphExecutionDescriptor} = Ref{MPSGraphExecutionDescriptor}()
+    function default_exec_desc()
+        if !isassigned(_default_exec_desc)
+            compDesc = MPSGraphCompilationDescriptor()
+            # Use optimization level 0 to avoid operations being moved to the neural engine
+            compDesc.optimizationLevel = MPSGraphOptimizationLevel0
+
+            _default_exec_desc[] = MPSGraphExecutionDescriptor()
+            _default_exec_desc[].compilationDescriptor = compDesc
+        end
+        _default_exec_desc[]
+    end
+end
+
+
 function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
     graph = MPSGraph()

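Both branches above expose the same call site, `default_exec_desc()`: on Julia 1.12+ `Base.OncePerProcess` builds the execution descriptor lazily on first use and caches it for the rest of the process, while the `Ref`-based fallback reproduces that behaviour by hand. A minimal sketch of the same once-per-process pattern, using a plain `Int` payload so it runs without Metal (the `expensive_setup` name is made up for illustration):

```julia
# Sketch of the once-per-process initialization pattern (hypothetical payload).
@static if isdefined(Base, :OncePerProcess)  # Julia >= 1.12
    const expensive_setup = OncePerProcess{Int}() do
        println("running setup")  # runs only on the first call
        42
    end
else
    const _expensive_setup = Ref{Int}()
    function expensive_setup()
        if !isassigned(_expensive_setup)
            println("running setup")
            _expensive_setup[] = 42
        end
        _expensive_setup[]
    end
end

expensive_setup()  # prints "running setup" and returns 42
expensive_setup()  # returns the cached 42 without repeating the setup
```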
@@ -11,9 +38,10 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
         placeC => MPSGraphTensorData(c)
     )

-    # cast to Float32 for better performance
-    castA = castTensor(graph, placeA, Float32, "castA")
-    castB = castTensor(graph, placeB, Float32, "castB")
+    # cast to output eltype if input type is an integer type
+    castT = Tab <: Integer ? Tc : Tab
+    castA = castTensor(graph, placeA, castT, "castA")
+    castB = castTensor(graph, placeB, castT, "castB")

     transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
     transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
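The `castT` change means floating-point inputs now stay in their own element type and only integer inputs are widened to the output type before the multiply. A small illustrative check of how the ternary resolves for a few type combinations (the `choose_cast_type` helper is hypothetical, not part of the package):

```julia
# Hypothetical helper mirroring the castT ternary above, for illustration only.
choose_cast_type(Tab::Type, Tc::Type) = Tab <: Integer ? Tc : Tab

choose_cast_type(Int32, Float32)    # Float32: integer inputs are cast to the output eltype
choose_cast_type(Float16, Float16)  # Float16: no more forced round-trip through Float32
choose_cast_type(Float32, Float32)  # Float32: used as-is
```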
@@ -35,13 +63,13 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
     matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA)

     afteralpha = let
-        alphatensor = constantWithScalar(graph, alpha, Float32)
+        alphatensor = constantWithScalar(graph, alpha, castT)
         multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
     end

     afterbeta = let
-        betatensor = constantWithScalar(graph, beta, Float32)
-        castplaceC = castTensor(graph, placeC, Float32, "castplaceC")
+        betatensor = constantWithScalar(graph, beta, castT)
+        castplaceC = castTensor(graph, placeC, castT, "castplaceC")
         betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
         afterbeta = additionWithPrimaryTensor(graph, afteralpha, betaC)
     end
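With the scalar constants now created in `castT` as well, the graph still encodes the usual GEMM update, `C = alpha * (A * B) + beta * C`, just without an intermediate Float32 pass for non-integer inputs. A plain-Julia reference for that arithmetic, ignoring the transpose handling and the row-major argument order of the graph call (reference sketch only, not the GPU path):

```julia
# CPU-side equivalent of the update the graph computes (reference sketch).
function gemm_update!(c, a, b, alpha, beta)
    c .= alpha .* (a * b) .+ beta .* c
    return c
end
```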
@@ -53,7 +81,7 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
     )

     cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
-    encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict))
+    encode!(cmdbuf, graph, NSDictionary(feeds), nil, NSDictionary(resultdict), default_exec_desc())
     commit!(cmdbuf)
     wait_completed(cmdbuf)

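The last hunk passes the cached descriptor to `encode!` so the graph is compiled with `MPSGraphOptimizationLevel0`. A hedged end-to-end sketch of exercising `_matmul!` directly, assuming Metal.jl is loaded and the code is run from inside the module that defines this internal helper (argument order taken from the signature above; this is for illustration only):

```julia
using Metal

a = MtlArray(rand(Float16, 4, 8))
b = MtlArray(rand(Float16, 8, 3))
c = MtlArray(zeros(Float16, 4, 3))

# c = 1 * (a * b) + 0 * c, no transposes; the encode! above attaches the
# process-wide execution descriptor with optimization level 0.
_matmul!(c, a, b, 1, 0, false, false)
```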