Skip to content

Commit 6e51f5a

Browse files
committed
Implementation more similar to MPS implementation
1 parent 61dbaaa commit 6e51f5a

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ module MPSGraphs
1010

1111
using ..Metal
1212
using .MTL
13-
using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray, exportToMtlArray!
13+
using .MPS
14+
using .MPS: MPSDataType, MPSShape, exportDataWithCommandBuffer
1415

1516
using CEnum
1617
using ObjectiveC, .Foundation, .Dispatch

lib/mpsgraphs/matmul.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,33 +32,51 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
3232
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
3333
end
3434

35-
feed = Dict(
35+
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
3636
placeA => MPSGraphTensorData(a),
3737
placeB => MPSGraphTensorData(b)
3838
)
3939

4040
afterbeta = if beta == 0
4141
afteralpha
4242
else
43-
placeC = placeholderTensor(graph, UInt.(size(c)), Tc)
44-
feed[placeC] = MPSGraphTensorData(c)
43+
placeC = placeholderTensor(graph, size(c), Tc)
44+
feeds[placeC] = MPSGraphTensorData(c)
4545
betatensor = constantWithScalar(graph, beta, Tc)
4646
betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
4747
additionWithPrimaryTensor(graph, afteralpha, betaC)
4848
end
4949

50-
res = run(graph, feed, [afterbeta])
51-
resultdata = only(Dict{MPSGraphTensor, MPSGraphTensorData}(res)).second
50+
# Encode and commit matmul kernel
51+
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
52+
resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta]))
53+
commitAndContinue!(cmdbuf)
5254

53-
return MPSNDArray(resultdata)
55+
resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta]))
56+
57+
return cmdbuf, MPSNDArray(resultdata)
5458
end
5559

5660
function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
57-
resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
58-
return exportToMtlArray!(c, resultndarr)
61+
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
62+
63+
commit!(cmdbuf) do cmdBuf
64+
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)
65+
end
66+
67+
wait_completed(cmdbuf)
68+
69+
return c
5970
end
6071

6172
function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
62-
resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
63-
return exportToMtlArray!(c, resultndarr)
73+
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
74+
75+
commit!(cmdbuf) do cmdBuf
76+
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)
77+
end
78+
79+
wait_completed(cmdbuf)
80+
81+
return c
6482
end

0 commit comments

Comments
 (0)