Skip to content

Commit 7267040

Browse files
committed
Support more complex broadcasting behaviour
Will unblock NNlib issue 614
1 parent 82e0f8b commit 7267040

File tree

1 file changed

+26
-33
lines changed

1 file changed

+26
-33
lines changed

lib/mpsgraphs/matmul.jl

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,55 @@
1-
function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
1+
function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
22
graph = MPSGraph()
33

44
placeA = placeholderTensor(graph, size(a), Tab)
55
placeB = placeholderTensor(graph, size(b), Tab)
6-
outputTensorData = MPSGraphTensorData(c)
6+
placeC = placeholderTensor(graph, size(c), Tc)
77

88
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
99
placeA => MPSGraphTensorData(a),
10-
placeB => MPSGraphTensorData(b)
10+
placeB => MPSGraphTensorData(b),
11+
placeC => MPSGraphTensorData(c)
1112
)
1213

13-
castA, castB = if Tab != Float32
14-
castTensor(graph, placeA, Float32, "castA"),
15-
castTensor(graph, placeB, Float32, "castB")
16-
else
17-
placeA, placeB
18-
end
14+
# cast to Float32 for better performance
15+
castA = castTensor(graph, placeA, Float32, "castA")
16+
castB = castTensor(graph, placeB, Float32, "castB")
1917

20-
transA = if transpose_a
21-
transposeTensor(graph, castA, 0, 1, "transpose_a")
22-
else
23-
castA
24-
end
18+
transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
19+
transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
20+
21+
nBatchA = Na == 2 ? 1 : size(transA)[1]
22+
nBatchB = Nb == 2 ? 1 : size(transB)[1]
2523

26-
transB = if transpose_b
27-
transposeTensor(graph, castB, 0, 1, "transpose_b")
24+
# for batched matmul between differently sized tensors
25+
broadcastA, broadcastB = if nBatchA == nBatchB
26+
transA, transB
27+
elseif Na == 1
28+
broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB
29+
elseif Nb == 1
30+
transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...]))
2831
else
29-
castB
32+
transA, transB
3033
end
3134

32-
matmul = matrixMultiplicationWithPrimaryTensor(graph, transB, transA)
35+
matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA)
3336

34-
afteralpha = if isone(alpha)
35-
matmul
36-
else
37+
afteralpha = let
3738
alphatensor = constantWithScalar(graph, alpha, Float32)
3839
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
3940
end
4041

41-
afterbeta = if iszero(beta)
42-
afteralpha
43-
else
44-
placeC = placeholderTensor(graph, size(c), Tc)
45-
feeds[placeC] = outputTensorData
42+
afterbeta = let
4643
betatensor = constantWithScalar(graph, beta, Float32)
4744
castplaceC = castTensor(graph, placeC, Float32, "castplaceC")
4845
betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
49-
additionWithPrimaryTensor(graph, afteralpha, betaC)
46+
afterbeta = additionWithPrimaryTensor(graph, afteralpha, betaC)
5047
end
5148

52-
castC = if Tc != Float32
53-
castTensor(graph, afterbeta, Tc, "castC")
54-
else
55-
afterbeta
56-
end
49+
castC = castTensor(graph, afterbeta, Tc, "castC")
5750

5851
resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
59-
castC => outputTensorData
52+
castC => feeds[placeC]
6053
)
6154

6255
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))

0 commit comments

Comments
 (0)