Commit 1432aca

Use MPSGraph for matrix multiplication

1 parent bc12f52, commit 1432aca

7 files changed: +267 −23

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 25 additions & 1 deletion
@@ -10,16 +10,40 @@ module MPSGraphs

 using ..Metal
 using .MTL
-using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray
+using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray, exportToMtlArray!

 using CEnum
 using ObjectiveC, .Foundation, .Dispatch

+# Valid combinations of input (A and B matrices) and output (C) types.
+# The commented-out type combinations work, but are slower than with MPSMatrixMultiplication.
+const MPSGRAPH_VALID_MATMUL_TYPES =
+    [
+        # (Int8, Float16),
+        # (Int8, Float32),
+        # (Int16, Float32),
+        (Float16, Float16),
+        (Float16, Float32),
+        (Float32, Float32),
+    ]
+
+const MPSGRAPH_VALID_MATVECMUL_TYPES =
+    [
+        (Int8, Float16),
+        (Int8, Float32),
+        (Int16, Float32),
+        (Float16, Float16),
+        (Float16, Float32),
+        (Float32, Float32),
+    ]
+
 include("libmpsgraph.jl")

 include("core.jl")
 include("tensor.jl")
 include("operations.jl")
 include("random.jl")

+include("matmul.jl")
+
 end
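
Each entry in these tables pairs the shared element type of the A and B inputs with the element type of the C output; (Float16, Float32), for example, means Float16 inputs accumulated into a Float32 result. A quick sanity check of the tables, assuming the submodule is reachable as Metal.MPSGraphs:

    julia> using Metal

    julia> (Float16, Float32) in Metal.MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
    true

    julia> (Int8, Float16) in Metal.MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES  # matvec-only combination
    false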

lib/mpsgraphs/matmul.jl

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
+    graph = MPSGraph()
+
+    placeA = placeholderTensor(graph, size(a), Tab)
+    placeB = placeholderTensor(graph, size(b), Tab)
+
+    castA, castB = if Tc != Tab
+        castTensor(graph, placeA, Tc, "castA"),
+        castTensor(graph, placeB, Tc, "castB")
+    else
+        placeA, placeB
+    end
+
+    transA = if transpose_a
+        transposeTensor(graph, castA, 0, 1, "transpose_a")
+    else
+        castA
+    end
+
+    transB = if transpose_b
+        transposeTensor(graph, castB, 0, 1, "transpose_b")
+    else
+        castB
+    end
+
+    matmul = matrixMultiplicationWithPrimaryTensor(graph, transB, transA)
+
+    afteralpha = if alpha == 1
+        matmul
+    else
+        alphatensor = constantWithScalar(graph, alpha, Tc)
+        multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
+    end
+
+    feed = Dict(
+        placeA => MPSGraphTensorData(a),
+        placeB => MPSGraphTensorData(b)
+    )
+
+    afterbeta = if beta == 0
+        afteralpha
+    else
+        placeC = placeholderTensor(graph, UInt.(size(c)), Tc)
+        feed[placeC] = MPSGraphTensorData(c)
+        betatensor = constantWithScalar(graph, beta, Tc)
+        betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
+        additionWithPrimaryTensor(graph, afteralpha, betaC)
+    end
+
+    res = run(graph, feed, [afterbeta])
+    resultdata = only(Dict{MPSGraphTensor, MPSGraphTensorData}(res)).second
+
+    return MPSNDArray(resultdata)
+end
+
+function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
+    resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
+    return exportToMtlArray!(c, resultndarr)
+end
+
+function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
+    resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
+    return exportToMtlArray!(c, resultndarr)
+end
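
Both entry points mutate c in place, computing c = alpha * a * b + beta * c with optional transposes; with the defaults (alpha = true, beta = false) the destination's initial contents are ignored. A minimal usage sketch, assuming a Metal-capable device and that the MPSGraphs submodule is in scope as in the tests below:

    using Metal
    using Metal: MPSGraphs

    A = MtlArray(rand(Float16, 4, 4))      # Float16 inputs...
    B = MtlArray(rand(Float16, 4, 4))
    C = MtlArray{Float32}(undef, 4, 4)     # ...accumulated into a Float32 output

    MPSGraphs.graph_matmul!(C, A, B)       # C .= A * B

    b = MtlArray(rand(Float16, 4))
    c = MtlArray{Float32}(undef, 4)
    MPSGraphs.graph_matvecmul!(c, A, b)    # c .= A * b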

lib/mpsgraphs/operations.jl

Lines changed: 41 additions & 1 deletion
@@ -1,11 +1,51 @@

-function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name="matmul")
+function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
+    obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
+                                     toType:toType::MPSDataType
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
+    obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
+                                     dataType:dataType::MPSDataType]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
     obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
                                      secondaryTensor:secondary::id{MPSGraphTensor}
                                      name:name::id{NSString}]::id{MPSGraphTensor}
     MPSGraphTensor(obj)
 end

+function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul")
+    obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+                                     secondaryTensor:secondary::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add")
+    obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
+                                     secondaryTensor:secondary::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
+    obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
+                                     dimension:dimension::NSUInteger
+                                     withDimension:withDimension::NSUInteger
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
+    obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
 run(graph::MPSGraph, feeds::Dict, targetTensors::Vector) = run(graph, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
 function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
     obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
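
Each wrapper binds one MPSGraph Objective-C method and returns the resulting MPSGraphTensor node, so they compose directly into larger graphs. A minimal sketch of the pattern, scaling a matrix by two on the GPU; placeholderTensor's (graph, dims, eltype) signature is inferred from its uses in matmul.jl, and a Metal-capable device is assumed:

    graph = MPSGraph()

    x   = placeholderTensor(graph, UInt.((4, 4)), Float32)
    two = constantWithScalar(graph, 2.0, Float32)
    y   = multiplicationWithPrimaryTensor(graph, two, x, "scale")  # y = 2 .* x

    a   = MtlArray(rand(Float32, 4, 4))
    res = run(graph, Dict(x => MPSGraphTensorData(MPSMatrix(a))), [y])
    # res maps each requested tensor to its MPSGraphTensorData;
    # unpack it the way _matmul! does above.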

lib/mpsgraphs/tensor.jl

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,8 @@ function MPSGraphTensorData(vector::MPSVector)
     @objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData}
     return tensor
 end
-MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
-# MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))
+# MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
+MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))

 # rank must be between 1 and 16 inclusive
 function MPSGraphTensorData(vector::MPSVector, rank)

src/linalg.jl

Lines changed: 26 additions & 14 deletions
@@ -2,6 +2,24 @@ using LinearAlgebra
 using LinearAlgebra: MulAddMul, wrap
 using .MPS
 using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
+using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
+                  graph_matmul!, graph_matvecmul!
+
+@inline function supports_mps_matmul(A, B, C, valid_types)
+    MPS.is_supported(device(A)) &&
+        eltype(A) == eltype(B) &&
+        (eltype(A), eltype(C)) in valid_types
+end
+
+@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
+    MPS.is_supported(device(A)) &&
+        eltype(A) == eltype(B) &&
+        (eltype(A), eltype(C)) in valid_types &&
+        # TODO: remove this limitation
+        A.offset == 0 &&
+        B.offset == 0 &&
+        C.offset == 0
+end

 LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
     LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

@@ -28,13 +46,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
     transA = tA == 'T' || tA == 'C'
     transB = tB == 'T' || tB == 'C'

-    typA = eltype(A)
-    typB = eltype(B)
-    typC = eltype(C)
-
-    # If possible, dispatch to performance shaders
-    if MPS.is_supported(device()) &&
-            typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
+    # If possible, dispatch to MPSGraphs, then performance shaders
+    if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
+        graph_matmul!(C, A, B, alpha, beta, transA, transB)
+    elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
         matmul!(C, A, B, alpha, beta, transA, transB)
     else
         GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)

@@ -66,13 +81,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
     transA = tA == 'T' || tA == 'C'

-    typA = eltype(A)
-    typB = eltype(B)
-    typC = eltype(C)
-
-    # If possible, dispatch to performance shaders
-    if MPS.is_supported(device()) &&
-            typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
+    # If possible, dispatch to MPSGraphs, then performance shaders
+    if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
+        graph_matvecmul!(C, A, B, alpha, beta, transA)
+    elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES)
         matvecmul!(C, A, B, alpha, beta, transA)
     else
         GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
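
With these predicates in place, ordinary LinearAlgebra calls pick the fastest eligible backend: MPSGraph when the device supports MPS, the input eltypes match, the (input, output) pair is in the valid-types table, and all three arrays start at offset zero; otherwise MPSMatrixMultiplication, and failing that the generic GPUArrays kernel. A sketch of the resulting behavior, assuming this commit's dispatch is active:

    using LinearAlgebra, Metal

    A = MtlArray(rand(Float16, 32, 32))
    B = MtlArray(rand(Float16, 32, 32))
    C = MtlArray{Float32}(undef, 32, 32)

    # (Float16, Float32) is in MPSGRAPH_VALID_MATMUL_TYPES, so this
    # generic_matmatmul! call should route through graph_matmul!.
    mul!(C, A, B)

    # (Int8, Float32) is valid only for matrix-vector products:
    M = MtlArray(rand(Int8, 32, 32))
    v = MtlArray(rand(Int8, 32))
    w = MtlArray{Float32}(undef, 32)
    mul!(w, M, v)                          # eligible for graph_matvecmul!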

test/mps/linalg.jl

Lines changed: 7 additions & 5 deletions
@@ -34,17 +34,19 @@ if MPS.is_supported(device())
 end

 @testset "batched matrix matrix multiplication" begin
-    N = 10
+    M = 8
+    N = 7
+    P = 9
     batch_size = 3

-    rows_a = N
+    rows_a = M
     cols_a = N

     rows_b = N
-    cols_b = N
+    cols_b = P

-    rows_c = rows_a
-    cols_c = cols_b
+    rows_c = M
+    cols_c = P

     alpha = Float64(1)
     beta = Float64(1)

test/mpsgraphs/linalg.jl

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+using LinearAlgebra
+
+
+if MPS.is_supported(device())
+
+@testset "mixed-precision matrix matrix multiplication" begin
+    N = 10
+    rows_a = N
+    cols_a = N
+
+    rows_b = N
+    cols_b = N
+
+    rows_c = rows_a
+    cols_c = cols_b
+
+    alpha = Float64(1)
+    beta = Float64(1)
+
+    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+        arr_a = rand(input_jl_type, (rows_a, cols_a))
+        arr_b = rand(input_jl_type, (rows_b, cols_b))
+        arr_c = zeros(accum_jl_type, (rows_c, cols_c))
+
+        buf_a = MtlArray{input_jl_type}(arr_a)
+        buf_b = MtlArray{input_jl_type}(arr_b)
+        buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
+
+        truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
+
+        MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+        @test all(Array(buf_c) .≈ truth_c)
+    end
+end
+
+# XXX: Batched matmul not yet working
+@testset "batched matrix matrix multiplication" begin
+    M = 8
+    N = 7
+    P = 9
+    batch_size = 3
+
+    rows_a = M
+    cols_a = N
+
+    rows_b = N
+    cols_b = P
+
+    rows_c = M
+    cols_c = P
+
+    alpha = Float64(1)
+    beta = Float64(1)
+
+    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+        arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
+        arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
+        arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
+
+        buf_a = MtlArray{input_jl_type}(arr_a)
+        buf_b = MtlArray{input_jl_type}(arr_b)
+        buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+
+        truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+        for i in 1:batch_size
+            @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
+        end
+
+        MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+        @test all(Array(buf_c) .≈ truth_c)
+    end
+end
+
+@testset "mixed-precision matrix vector multiplication" begin
+    N = 10
+    rows = N
+    cols = N
+
+    alpha = Float64(1)
+    beta = Float64(0)
+
+    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
+        arr_a = rand(input_jl_type, (rows, cols))
+        arr_b = rand(input_jl_type, rows)
+        arr_c = zeros(accum_jl_type, rows)
+
+        buf_a = MtlArray{input_jl_type}(arr_a)
+        buf_b = MtlArray{input_jl_type}(arr_b)
+        buf_c = MtlArray{accum_jl_type}(undef, rows)
+
+        truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
+
+        MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+        @test all(Array(buf_c) .≈ truth_c)
+        # @test Array(buf_c) ≈ truth_c
+    end
+end
+
+end
