Skip to content

Commit 54fad8a

Browse files
committed
Still working
1 parent b026d69 commit 54fad8a

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

lib/mpsgraphs/matmul.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
1+
function _matmul!(cmdbuf::MPSCommandBuffer, c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
22
graph = MPSGraph()
33

44
placeA = placeholderTensor(graph, size(a), Tab)
@@ -7,6 +7,10 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
77
castA, castB = if Tc != Tab
88
castTensor(graph, placeA, Tc, "castA"),
99
castTensor(graph, placeB, Tc, "castB")
10+
11+
# castA, castB = if Tab != Float32
12+
# castTensor(graph, placeA, Float32, "castA"),
13+
# castTensor(graph, placeB, Float32, "castB")
1014
else
1115
placeA, placeB
1216
end
@@ -47,18 +51,28 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
4751
additionWithPrimaryTensor(graph, afteralpha, betaC)
4852
end
4953

54+
castC = if Tc != Float32
55+
afterbeta
56+
# castTensor(graph, afterbeta, Tc, "castC")
57+
else
58+
afterbeta
59+
end
60+
5061
# Encode and commit matmul kernel
51-
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
52-
resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta]))
62+
# resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta]))
63+
resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([castC]))
5364
commitAndContinue!(cmdbuf)
5465

55-
resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta]))
66+
# resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta]))
67+
resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[castC]))
5668

57-
return cmdbuf, MPSNDArray(resultdata)
69+
return MPSNDArray(resultdata)
5870
end
5971

6072
function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
61-
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
73+
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
74+
75+
resultndarr = _matmul!(cmdbuf, MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
6276

6377
commit!(cmdbuf) do cmdBuf
6478
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)
@@ -70,7 +84,9 @@ function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab,
7084
end
7185

7286
function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
73-
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
87+
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
88+
89+
resultndarr = _matmul!(cmdbuf, MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
7490

7591
commit!(cmdbuf) do cmdBuf
7692
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)

0 commit comments

Comments
 (0)