Skip to content

Commit ea8361d

Browse files
committed
Support more type combinations
1 parent 956fffb commit ea8361d

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@ using CEnum
1616
using ObjectiveC, .Foundation, .Dispatch
1717

1818
# Valid combination of input (A and B matrices) and output (C) types
19-
# TODO: support the commented type combinations
19+
# The commented type combinations work but are slower than with MPSMatrixMultiplicatiom
2020
const MPSGRAPH_VALID_MATMUL_TYPES =
2121
[
2222
# (Int8, Float16),
2323
# (Int8, Float32),
2424
# (Int16, Float32),
2525
(Float16, Float16),
26-
# (Float16, Float32),
26+
(Float16, Float32),
2727
(Float32, Float32),
2828
]
2929

3030
const MPSGRAPH_VALID_MATVECMUL_TYPES =
3131
[
32+
(Int8, Float16),
33+
(Int8, Float32),
34+
(Int16, Float32),
3235
(Float16, Float16),
33-
# (Float16, Float32),
36+
(Float16, Float32),
3437
(Float32, Float32),
3538
]
3639

lib/mpsgraphs/matmul.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,34 @@
1-
function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatrix, ::Type{T3}, alpha::Number, beta::Number, transpose_a, transpose_b) where {T1, T2, T3}
1+
function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
22
graph = MPSGraph()
33

4-
placeA = placeholderTensor(graph, size(a), T2)
5-
placeB = placeholderTensor(graph, size(b), T3)
4+
placeA = placeholderTensor(graph, size(a), Tab)
5+
placeB = placeholderTensor(graph, size(b), Tab)
6+
7+
castA, castB = if Tc != Tab
8+
castTensor(graph, placeA, Tc, "castA"),
9+
castTensor(graph, placeB, Tc, "castB")
10+
else
11+
placeA, placeB
12+
end
613

714
transA = if transpose_a
8-
transposeTensor(graph, placeA, 0, 1, "transpose_a")
15+
transposeTensor(graph, castA, 0, 1, "transpose_a")
916
else
10-
placeA
17+
castA
1118
end
1219

1320
transB = if transpose_b
14-
transposeTensor(graph, placeB, 0, 1, "transpose_b")
21+
transposeTensor(graph, castB, 0, 1, "transpose_b")
1522
else
16-
placeB
23+
castB
1724
end
1825

1926
matmul = matrixMultiplicationWithPrimaryTensor(graph, transB, transA)
2027

2128
afteralpha = if alpha == 1
2229
matmul
2330
else
24-
alphatensor = constantWithScalar(graph, alpha, T1)
31+
alphatensor = constantWithScalar(graph, alpha, Tc)
2532
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
2633
end
2734

@@ -33,9 +40,9 @@ function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatr
3340
afterbeta = if beta == 0
3441
afteralpha
3542
else
36-
placeC = placeholderTensor(graph, UInt.(size(c)), T1)
43+
placeC = placeholderTensor(graph, UInt.(size(c)), Tc)
3744
feed[placeC] = MPSGraphTensorData(c)
38-
betatensor = constantWithScalar(graph, beta, T1)
45+
betatensor = constantWithScalar(graph, beta, Tc)
3946
betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
4047
additionWithPrimaryTensor(graph, afteralpha, betaC)
4148
end
@@ -46,12 +53,12 @@ function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatr
4653
return MPSNDArray(resultdata)
4754
end
4855

49-
function graph_matmul!(c::MtlArray{T1, N}, a::MtlArray{T2, N}, b::MtlArray{T3, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {T1, T2, T3, N}
50-
resultndarr = _matmul!(MPSMatrix(c), T1, MPSMatrix(a), T2, MPSMatrix(b), T3, alpha, beta, transpose_a, transpose_b)
56+
function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
57+
resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
5158
return exportToMtlArray!(c, resultndarr)
5259
end
5360

54-
function graph_matvecmul!(c::MtlVector{T1}, a::MtlMatrix{T2}, b::MtlVector{T3}, alpha::Number = true, beta::Number = false, transpose = false) where {T1, T2, T3}
55-
resultndarr = _matmul!(MPSMatrix(c), T1, MPSMatrix(a), T2, MPSMatrix(b), T3, alpha, beta, transpose, false)
61+
function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
62+
resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
5663
return exportToMtlArray!(c, resultndarr)
5764
end

lib/mpsgraphs/operations.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11

2+
function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name="cast")
3+
obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
4+
toType:toType::MPSDataType
5+
name:name::id{NSString}]::id{MPSGraphTensor}
6+
MPSGraphTensor(obj)
7+
end
8+
29
function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
310
obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
411
dataType:dataType::MPSDataType]::id{MPSGraphTensor}

0 commit comments

Comments
 (0)