1- function _matmul! (c:: MPSMatrix , :: Type{T1 } , a:: MPSMatrix , :: Type{T2} , b:: MPSMatrix , :: Type{T3 } , alpha:: Number , beta:: Number , transpose_a, transpose_b) where {T1, T2, T3 }
1+ function _matmul! (c:: MPSMatrix , :: Type{Tc } , a:: MPSMatrix , b:: MPSMatrix , :: Type{Tab } , alpha:: Number , beta:: Number , transpose_a, transpose_b) where {Tc, Tab }
22 graph = MPSGraph ()
33
4- placeA = placeholderTensor (graph, size (a), T2)
5- placeB = placeholderTensor (graph, size (b), T3)
4+ placeA = placeholderTensor (graph, size (a), Tab)
5+ placeB = placeholderTensor (graph, size (b), Tab)
6+
7+ castA, castB = if Tc != Tab
8+ castTensor (graph, placeA, Tc, " castA" ),
9+ castTensor (graph, placeB, Tc, " castB" )
10+ else
11+ placeA, placeB
12+ end
613
714 transA = if transpose_a
8- transposeTensor (graph, placeA , 0 , 1 , " transpose_a" )
15+ transposeTensor (graph, castA , 0 , 1 , " transpose_a" )
916 else
10- placeA
17+ castA
1118 end
1219
1320 transB = if transpose_b
14- transposeTensor (graph, placeB , 0 , 1 , " transpose_b" )
21+ transposeTensor (graph, castB , 0 , 1 , " transpose_b" )
1522 else
16- placeB
23+ castB
1724 end
1825
1926 matmul = matrixMultiplicationWithPrimaryTensor (graph, transB, transA)
2027
2128 afteralpha = if alpha == 1
2229 matmul
2330 else
24- alphatensor = constantWithScalar (graph, alpha, T1 )
31+ alphatensor = constantWithScalar (graph, alpha, Tc )
2532 multiplicationWithPrimaryTensor (graph, alphatensor, matmul)
2633 end
2734
@@ -33,9 +40,9 @@ function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatr
3340 afterbeta = if beta == 0
3441 afteralpha
3542 else
36- placeC = placeholderTensor (graph, UInt .(size (c)), T1 )
43+ placeC = placeholderTensor (graph, UInt .(size (c)), Tc )
3744 feed[placeC] = MPSGraphTensorData (c)
38- betatensor = constantWithScalar (graph, beta, T1 )
45+ betatensor = constantWithScalar (graph, beta, Tc )
3946 betaC = multiplicationWithPrimaryTensor (graph, betatensor, placeC)
4047 additionWithPrimaryTensor (graph, afteralpha, betaC)
4148 end
@@ -46,12 +53,12 @@ function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatr
4653 return MPSNDArray (resultdata)
4754end
4855
49- function graph_matmul! (c:: MtlArray{T1 , N} , a:: MtlArray{T2 , N} , b:: MtlArray{T3 , N} , alpha:: Number = true , beta:: Number = false , transpose_a = false , transpose_b = false ) where {T1, T2, T3 , N}
50- resultndarr = _matmul! (MPSMatrix (c), T1 , MPSMatrix (a), T2, MPSMatrix (b), T3 , alpha, beta, transpose_a, transpose_b)
56+ function graph_matmul! (c:: MtlArray{Tc , N} , a:: MtlArray{Tab , N} , b:: MtlArray{Tab , N} , alpha:: Number = true , beta:: Number = false , transpose_a = false , transpose_b = false ) where {Tc, Tab , N}
57+ resultndarr = _matmul! (MPSMatrix (c), Tc , MPSMatrix (a), MPSMatrix (b), Tab , alpha, beta, transpose_a, transpose_b)
5158 return exportToMtlArray! (c, resultndarr)
5259end
5360
54- function graph_matvecmul! (c:: MtlVector{T1 } , a:: MtlMatrix{T2 } , b:: MtlVector{T3 } , alpha:: Number = true , beta:: Number = false , transpose = false ) where {T1, T2, T3 }
55- resultndarr = _matmul! (MPSMatrix (c), T1 , MPSMatrix (a), T2, MPSMatrix (b), T3 , alpha, beta, transpose, false )
61+ function graph_matvecmul! (c:: MtlVector{Tc } , a:: MtlMatrix{Tab } , b:: MtlVector{Tab } , alpha:: Number = true , beta:: Number = false , transpose = false ) where {Tc, Tab }
62+ resultndarr = _matmul! (MPSMatrix (c), Tc , MPSMatrix (a), MPSMatrix (b), Tab , alpha, beta, transpose, false )
5663 return exportToMtlArray! (c, resultndarr)
5764end
0 commit comments