1- function  _matmul! (c:: MPSMatrix , :: Type{Tc} , a:: MPSMatrix , b:: MPSMatrix , :: Type{Tab} , alpha:: Number , beta:: Number , transpose_a, transpose_b) where  {Tc, Tab}
1+ function  _matmul! (cmdbuf :: MPSCommandBuffer ,  c:: MPSMatrix , :: Type{Tc} , a:: MPSMatrix , b:: MPSMatrix , :: Type{Tab} , alpha:: Number , beta:: Number , transpose_a, transpose_b) where  {Tc, Tab}
22    graph =  MPSGraph ()
33
44    placeA =  placeholderTensor (graph, size (a), Tab)
@@ -7,6 +7,10 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
77    castA, castB =  if  Tc !=  Tab
88        castTensor (graph, placeA, Tc, " castA"  ),
99            castTensor (graph, placeB, Tc, " castB"  )
10+ 
11+     #  castA, castB = if Tab != Float32
12+     #      castTensor(graph, placeA, Float32, "castA"),
13+     #          castTensor(graph, placeB, Float32, "castB")
1014    else 
1115        placeA, placeB
1216    end 
@@ -47,18 +51,28 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
4751        additionWithPrimaryTensor (graph, afteralpha, betaC)
4852    end 
4953
54+     castC =  if  Tc !=  Float32
55+         afterbeta
56+         #  castTensor(graph, afterbeta, Tc, "castC")
57+     else 
58+         afterbeta
59+     end 
60+ 
5061    #  Encode and commit matmul kernel
51-     cmdbuf  =   MPSCommandBuffer (Metal . global_queue ( device () ))
52-     resultdict =  encode! (cmdbuf, graph, NSDictionary (feeds), NSArray ([afterbeta ]))
62+     #  resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta] ))
63+     resultdict =  encode! (cmdbuf, graph, NSDictionary (feeds), NSArray ([castC ]))
5364    commitAndContinue! (cmdbuf)
5465
55-     resultdata =  MPSGraphTensorData (id {MPSGraphTensorData} (resultdict[afterbeta]))
66+     #  resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta]))
67+     resultdata =  MPSGraphTensorData (id {MPSGraphTensorData} (resultdict[castC]))
5668
57-     return  cmdbuf,  MPSNDArray (resultdata)
69+     return  MPSNDArray (resultdata)
5870end 
5971
6072function  graph_matmul! (c:: MtlArray{Tc, N} , a:: MtlArray{Tab, N} , b:: MtlArray{Tab, N} , alpha:: Number  =  true , beta:: Number  =  false , transpose_a =  false , transpose_b =  false ) where  {Tc, Tab, N}
61-     cmdbuf, resultndarr =  _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose_a, transpose_b)
73+     cmdbuf =  MPSCommandBuffer (Metal. global_queue (device ()))
74+ 
75+     resultndarr =  _matmul! (cmdbuf, MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose_a, transpose_b)
6276
6377    commit! (cmdbuf) do  cmdBuf
6478        exportDataWithCommandBuffer (resultndarr, cmdBuf, c. data[], Tc, c. offset)
@@ -70,7 +84,9 @@ function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab,
7084end 
7185
7286function  graph_matvecmul! (c:: MtlVector{Tc} , a:: MtlMatrix{Tab} , b:: MtlVector{Tab} , alpha:: Number  =  true , beta:: Number  =  false , transpose =  false ) where  {Tc, Tab}
73-     cmdbuf, resultndarr =  _matmul! (MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose, false )
87+     cmdbuf =  MPSCommandBuffer (Metal. global_queue (device ()))
88+ 
89+     resultndarr =  _matmul! (cmdbuf, MPSMatrix (c), Tc, MPSMatrix (a), MPSMatrix (b), Tab, alpha, beta, transpose, false )
7490
7591    commit! (cmdbuf) do  cmdBuf
7692        exportDataWithCommandBuffer (resultndarr, cmdBuf, c. data[], Tc, c. offset)
0 commit comments