|
1 | | -function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab} |
| 1 | +function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb} |
2 | 2 | graph = MPSGraph() |
3 | 3 |
|
4 | 4 | placeA = placeholderTensor(graph, size(a), Tab) |
5 | 5 | placeB = placeholderTensor(graph, size(b), Tab) |
6 | | - outputTensorData = MPSGraphTensorData(c) |
| 6 | + placeC = placeholderTensor(graph, size(c), Tc) |
7 | 7 |
|
8 | 8 | feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( |
9 | 9 | placeA => MPSGraphTensorData(a), |
10 | | - placeB => MPSGraphTensorData(b) |
| 10 | + placeB => MPSGraphTensorData(b), |
| 11 | + placeC => MPSGraphTensorData(c) |
11 | 12 | ) |
12 | 13 |
|
13 | | - castA, castB = if Tab != Float32 |
14 | | - castTensor(graph, placeA, Float32, "castA"), |
15 | | - castTensor(graph, placeB, Float32, "castB") |
16 | | - else |
17 | | - placeA, placeB |
18 | | - end |
| 14 | + # cast to Float32 for better performance |
| 15 | + castA = castTensor(graph, placeA, Float32, "castA") |
| 16 | + castB = castTensor(graph, placeB, Float32, "castB") |
19 | 17 |
|
20 | | - transA = if transpose_a |
21 | | - transposeTensor(graph, castA, 0, 1, "transpose_a") |
22 | | - else |
23 | | - castA |
24 | | - end |
| 18 | + transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA |
| 19 | + transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB |
| 20 | + |
| 21 | + nBatchA = Na == 2 ? 1 : size(transA)[1] |
| 22 | + nBatchB = Nb == 2 ? 1 : size(transB)[1] |
25 | 23 |
|
26 | | - transB = if transpose_b |
27 | | - transposeTensor(graph, castB, 0, 1, "transpose_b") |
| 24 | + # for batched matmul between different sized tensors |
| 25 | + broadcastA, broadcastB = if nBatchA == nBatchB |
| 26 | + transA, transB |
| 27 | + elseif Na == 1 |
| 28 | + broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB |
| 29 | + elseif Nb == 1 |
| 30 | + transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) |
28 | 31 | else |
29 | | - castB |
| 32 | + transA, transB |
30 | 33 | end |
31 | 34 |
|
32 | | - matmul = matrixMultiplicationWithPrimaryTensor(graph, transB, transA) |
| 35 | + matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA) |
33 | 36 |
|
34 | | - afteralpha = if isone(alpha) |
35 | | - matmul |
36 | | - else |
| 37 | + afteralpha = let |
37 | 38 | alphatensor = constantWithScalar(graph, alpha, Float32) |
38 | 39 | multiplicationWithPrimaryTensor(graph, alphatensor, matmul) |
39 | 40 | end |
40 | 41 |
|
41 | | - afterbeta = if iszero(beta) |
42 | | - afteralpha |
43 | | - else |
44 | | - placeC = placeholderTensor(graph, size(c), Tc) |
45 | | - feeds[placeC] = outputTensorData |
| 42 | + afterbeta = let |
46 | 43 | betatensor = constantWithScalar(graph, beta, Float32) |
47 | 44 | castplaceC = castTensor(graph, placeC, Float32, "castplaceC") |
48 | 45 | betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC) |
49 | | - additionWithPrimaryTensor(graph, afteralpha, betaC) |
| 46 | + afterbeta = additionWithPrimaryTensor(graph, afteralpha, betaC) |
50 | 47 | end |
51 | 48 |
|
52 | | - castC = if Tc != Float32 |
53 | | - castTensor(graph, afterbeta, Tc, "castC") |
54 | | - else |
55 | | - afterbeta |
56 | | - end |
| 49 | + castC = castTensor(graph, afterbeta, Tc, "castC") |
57 | 50 |
|
58 | 51 | resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( |
59 | | - castC => outputTensorData |
| 52 | + castC => feeds[placeC] |
60 | 53 | ) |
61 | 54 |
|
62 | 55 | cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) |
|
0 commit comments