@@ -32,33 +32,51 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
3232 multiplicationWithPrimaryTensor (graph, alphatensor, matmul)
3333 end
3434
35- feed = Dict (
35+ feeds = Dict {MPSGraphTensor, MPSGraphTensorData} (
3636 placeA => MPSGraphTensorData (a),
3737 placeB => MPSGraphTensorData (b)
3838 )
3939
4040 afterbeta = if beta == 0
4141 afteralpha
4242 else
43- placeC = placeholderTensor (graph, UInt .( size (c) ), Tc)
44- feed [placeC] = MPSGraphTensorData (c)
43+ placeC = placeholderTensor (graph, size (c), Tc)
44+ feeds [placeC] = MPSGraphTensorData (c)
4545 betatensor = constantWithScalar (graph, beta, Tc)
4646 betaC = multiplicationWithPrimaryTensor (graph, betatensor, placeC)
4747 additionWithPrimaryTensor (graph, afteralpha, betaC)
4848 end
4949
50- res = run (graph, feed, [afterbeta])
51- resultdata = only (Dict {MPSGraphTensor, MPSGraphTensorData} (res)). second
50+ # Encode and commit matmul kernel
51+ cmdbuf = MPSCommandBuffer (Metal. global_queue (device ()))
52+ resultdict = encode! (cmdbuf, graph, NSDictionary (feeds), NSArray ([afterbeta]))
53+ commitAndContinue! (cmdbuf)
5254
53- return MPSNDArray (resultdata)
55+ resultdata = MPSGraphTensorData (id {MPSGraphTensorData} (resultdict[afterbeta]))
56+
57+ return cmdbuf, MPSNDArray (resultdata)
5458end
5559
5660function graph_matmul! (c:: MtlArray{Tc, N} , a:: MtlArray{Tab, N} , b:: MtlArray{Tab, N} , alpha:: Number = true , beta:: Number = false , transpose_a = false , transpose_b = false ) where {Tc, Tab, N}
57- resultndarr = _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose_a, transpose_b)
58- return exportToMtlArray! (c, resultndarr)
61+ cmdbuf, resultndarr = _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose_a, transpose_b)
62+
63+ commit! (cmdbuf) do cmdBuf
64+ exportDataWithCommandBuffer (resultndarr, cmdBuf, c. data[], Tc, c. offset)
65+ end
66+
67+ wait_completed (cmdbuf)
68+
69+ return c
5970end
6071
6172function graph_matvecmul! (c:: MtlVector{Tc} , a:: MtlMatrix{Tab} , b:: MtlVector{Tab} , alpha:: Number = true , beta:: Number = false , transpose = false ) where {Tc, Tab}
62- resultndarr = _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose, false )
63- return exportToMtlArray! (c, resultndarr)
73+ cmdbuf, resultndarr = _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose, false )
74+
75+ commit! (cmdbuf) do cmdBuf
76+ exportDataWithCommandBuffer (resultndarr, cmdBuf, c. data[], Tc, c. offset)
77+ end
78+
79+ wait_completed (cmdbuf)
80+
81+ return c
6482end
0 commit comments