|
1 | | -function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab} |
| 1 | +function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab} |
2 | 2 | graph = MPSGraph() |
3 | 3 |
|
4 | 4 | placeA = placeholderTensor(graph, size(a), Tab) |
5 | 5 | placeB = placeholderTensor(graph, size(b), Tab) |
| 6 | + outputTensorData = MPSGraphTensorData(c) |
| 7 | + |
| 8 | + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( |
| 9 | + placeA => MPSGraphTensorData(a), |
| 10 | + placeB => MPSGraphTensorData(b) |
| 11 | + ) |
6 | 12 |
|
7 | 13 | castA, castB = if Tc != Tab |
8 | 14 | castTensor(graph, placeA, Tc, "castA"), |
@@ -32,51 +38,32 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T |
32 | 38 | multiplicationWithPrimaryTensor(graph, alphatensor, matmul) |
33 | 39 | end |
34 | 40 |
|
35 | | - feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( |
36 | | - placeA => MPSGraphTensorData(a), |
37 | | - placeB => MPSGraphTensorData(b) |
38 | | - ) |
39 | | - |
40 | 41 | afterbeta = if beta == 0 |
41 | 42 | afteralpha |
42 | 43 | else |
43 | 44 | placeC = placeholderTensor(graph, size(c), Tc) |
44 | | - feeds[placeC] = MPSGraphTensorData(c) |
| 45 | + feeds[placeC] = outputTensorData |
45 | 46 | betatensor = constantWithScalar(graph, beta, Tc) |
46 | 47 | betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC) |
47 | 48 | additionWithPrimaryTensor(graph, afteralpha, betaC) |
48 | 49 | end |
49 | 50 |
|
50 | | - # Encode and commit matmul kernel |
51 | | - cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) |
52 | | - resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta])) |
53 | | - commitAndContinue!(cmdbuf) |
| 51 | + resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( |
| 52 | + afterbeta => outputTensorData |
| 53 | + ) |
54 | 54 |
|
55 | | - resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta])) |
| 55 | + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) |
| 56 | + encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict)) |
| 57 | + commit!(cmdbuf) |
| 58 | + wait_completed(cmdbuf) |
56 | 59 |
|
57 | | - return cmdbuf, MPSNDArray(resultdata) |
| 60 | + return c |
58 | 61 | end |
59 | 62 |
|
"""
    graph_matmul!(c, a, b, alpha=true, beta=false, transpose_a=false, transpose_b=false)

Forward every argument, unchanged and in order, to `_matmul!` and return its result.

`c`, `a` and `b` must share the same dimensionality `N`; `a` and `b` must share the
element type `Tab`, while `c` may use a different element type `Tc`.

# Arguments
- `c`: destination array handed to `_matmul!` as its first argument.
- `a`, `b`: input arrays.
- `alpha`, `beta`: scalar factors passed straight through (presumably the usual
  `alpha * a * b + beta * c` scaling — confirm against `_matmul!`).
- `transpose_a`, `transpose_b`: passed straight through; their interpretation is
  defined by `_matmul!`.
"""
function graph_matmul!(
        c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N},
        alpha::Number = true, beta::Number = false,
        transpose_a = false, transpose_b = false,
    ) where {Tc, Tab, N}
    return _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
end
71 | 66 |
|
"""
    graph_matvecmul!(c, a, b, alpha=true, beta=false, transpose=false)

Matrix-vector entry point: forward to `_matmul!` and return its result.

`a` is a matrix and `b`, `c` are vectors; `a` and `b` share the element type `Tab`,
while `c` may use a different element type `Tc`.

# Arguments
- `c`: destination vector handed to `_matmul!` as its first argument.
- `a`: input matrix.
- `b`: input vector.
- `alpha`, `beta`: scalar factors passed straight through (presumably the usual
  `alpha * a * b + beta * c` scaling — confirm against `_matmul!`).
- `transpose`: passed in the `transpose_a` position of `_matmul!`; the
  `transpose_b` position is always `false`.
"""
function graph_matvecmul!(
        c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab},
        alpha::Number = true, beta::Number = false,
        transpose = false,
    ) where {Tc, Tab}
    # Only the matrix operand can be transposed here; the vector never is.
    return _matmul!(c, a, b, alpha, beta, transpose, false)
end
0 commit comments