Skip to content

Commit 777ed11

Browse files
committed
Use optimization level 0 by default to disable use of the neural engine
1 parent e44c8d7 commit 777ed11

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

lib/mpsgraphs/matmul.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
1+
2+
# Process-wide default MPSGraphExecutionDescriptor whose compilation descriptor
# pins the optimization level to 0 so MPSGraph does not move operations onto
# the neural engine (see the commit message above this diff).
@static if isdefined(Base, :OncePerProcess) # VERSION >= v"1.12.0-DEV.1421"
    # OncePerProcess provides thread-safe, lazy, once-only initialization.
    const default_exec_desc = OncePerProcess{MPSGraphExecutionDescriptor}() do
        compDesc = MPSGraphCompilationDescriptor()
        # Use optimization level 0 to avoid operations being moved to the neural engine
        compDesc.optimizationLevel = MPSGraphOptimizationLevel0

        execDesc = MPSGraphExecutionDescriptor()
        execDesc.compilationDescriptor = compDesc
        execDesc
    end
else
    # Fallback for Julia < 1.12: emulate OncePerProcess with a Ref. The Ref is
    # guarded by a lock so that concurrent first calls cannot race on the
    # `isassigned` check and initialize the descriptor twice.
    const _default_exec_desc::Ref{MPSGraphExecutionDescriptor} = Ref{MPSGraphExecutionDescriptor}()
    const _default_exec_desc_lock = ReentrantLock()

    """
        default_exec_desc() -> MPSGraphExecutionDescriptor

    Return the lazily-initialized process-wide execution descriptor configured
    with `MPSGraphOptimizationLevel0`. Thread-safe.
    """
    function default_exec_desc()
        # Take the lock unconditionally: initialization happens once and
        # subsequent calls only pay an uncontended lock acquisition.
        return lock(_default_exec_desc_lock) do
            if !isassigned(_default_exec_desc)
                compDesc = MPSGraphCompilationDescriptor()
                # Use optimization level 0 to avoid operations being moved to the neural engine
                compDesc.optimizationLevel = MPSGraphOptimizationLevel0

                execDesc = MPSGraphExecutionDescriptor()
                execDesc.compilationDescriptor = compDesc
                _default_exec_desc[] = execDesc
            end
            _default_exec_desc[]
        end
    end
end
26+
27+
128
function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
229
graph = MPSGraph()
330

@@ -11,9 +38,10 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
1138
placeC => MPSGraphTensorData(c)
1239
)
1340

14-
# cast to Float32 for better performance
15-
castA = castTensor(graph, placeA, Float32, "castA")
16-
castB = castTensor(graph, placeB, Float32, "castB")
41+
# cast to output eltype if input type is an integer type
42+
castT = Tab <: Integer ? Tc : Tab
43+
castA = castTensor(graph, placeA, castT, "castA")
44+
castB = castTensor(graph, placeB, castT, "castB")
1745

1846
transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
1947
transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
@@ -35,13 +63,13 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
3563
matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA)
3664

3765
afteralpha = let
38-
alphatensor = constantWithScalar(graph, alpha, Float32)
66+
alphatensor = constantWithScalar(graph, alpha, castT)
3967
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
4068
end
4169

4270
afterbeta = let
43-
betatensor = constantWithScalar(graph, beta, Float32)
44-
castplaceC = castTensor(graph, placeC, Float32, "castplaceC")
71+
betatensor = constantWithScalar(graph, beta, castT)
72+
castplaceC = castTensor(graph, placeC, castT, "castplaceC")
4573
betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
4674
afterbeta = additionWithPrimaryTensor(graph, afteralpha, betaC)
4775
end
@@ -53,7 +81,7 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, a
5381
)
5482

5583
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
56-
encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict))
84+
encode!(cmdbuf, graph, NSDictionary(feeds), nil, NSDictionary(resultdict), default_exec_desc())
5785
commit!(cmdbuf)
5886
wait_completed(cmdbuf)
5987

0 commit comments

Comments
 (0)