
Commit 8c7e260

Use MPSGraph for matrix multiplication
1 parent c55bf2a commit 8c7e260

6 files changed: +240 −17 lines


lib/mpsgraphs/MPSGraphs.jl

Lines changed: 22 additions & 1 deletion
@@ -10,16 +10,37 @@ module MPSGraphs
 
 using ..Metal
 using .MTL
-using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray
+using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray, exportToMtlArray!
 
 using CEnum
 using ObjectiveC, .Foundation, .Dispatch
 
+# Valid combination of input (A and B matrices) and output (C) types
+# TODO: support the commented type combinations
+const MPSGRAPH_VALID_MATMUL_TYPES =
+    [
+        # (Int8, Float16),
+        # (Int8, Float32),
+        # (Int16, Float32),
+        (Float16, Float16),
+        # (Float16, Float32),
+        (Float32, Float32),
+    ]
+
+const MPSGRAPH_VALID_MATVECMUL_TYPES =
+    [
+        (Float16, Float16),
+        # (Float16, Float32),
+        (Float32, Float32),
+    ]
+
 include("libmpsgraph.jl")
 
 include("core.jl")
 include("tensor.jl")
 include("operations.jl")
 include("random.jl")
 
+include("matmul.jl")
+
 end

lib/mpsgraphs/matmul.jl

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+function _matmul!(c::MPSMatrix, ::Type{T1}, a::MPSMatrix, ::Type{T2}, b::MPSMatrix, ::Type{T3}, alpha::Number, beta::Number, transpose_a, transpose_b) where {T1, T2, T3}
+    graph = MPSGraph()
+
+    placeA = placeholderTensor(graph, size(a), T2)
+    placeB = placeholderTensor(graph, size(b), T3)
+
+    transA = if transpose_a
+        transposeTensor(graph, placeA, 0, 1, "transpose_a")
+    else
+        placeA
+    end
+
+    transB = if transpose_b
+        transposeTensor(graph, placeB, 0, 1, "transpose_b")
+    else
+        placeB
+    end
+
+    matmul = matrixMultiplicationWithPrimaryTensor(graph, transB, transA)
+
+    afteralpha = if alpha == 1
+        matmul
+    else
+        alphatensor = constantWithScalar(graph, alpha, T1)
+        multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
+    end
+
+    feed = Dict(
+        placeA => MPSGraphTensorData(a),
+        placeB => MPSGraphTensorData(b)
+    )
+
+    afterbeta = if beta == 0
+        afteralpha
+    else
+        placeC = placeholderTensor(graph, UInt.(size(c)), T1)
+        feed[placeC] = MPSGraphTensorData(c)
+        betatensor = constantWithScalar(graph, beta, T1)
+        betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
+        additionWithPrimaryTensor(graph, afteralpha, betaC)
+    end
+
+    res = run(graph, feed, [afterbeta])
+    resultdata = only(Dict{MPSGraphTensor, MPSGraphTensorData}(res)).second
+
+    return MPSNDArray(resultdata)
+end
+
+function graph_matmul!(c::MtlArray{T1, N}, a::MtlArray{T2, N}, b::MtlArray{T3, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {T1, T2, T3, N}
+    resultndarr = _matmul!(MPSMatrix(c), T1, MPSMatrix(a), T2, MPSMatrix(b), T3, alpha, beta, transpose_a, transpose_b)
+    return exportToMtlArray!(c, resultndarr)
+end
+
+function graph_matvecmul!(c::MtlVector{T1}, a::MtlMatrix{T2}, b::MtlVector{T3}, alpha::Number = true, beta::Number = false, transpose = false) where {T1, T2, T3}
+    resultndarr = _matmul!(MPSMatrix(c), T1, MPSMatrix(a), T2, MPSMatrix(b), T3, alpha, beta, transpose, false)
+    return exportToMtlArray!(c, resultndarr)
+end
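
For reference, a minimal usage sketch of the new entry point, mirroring the call pattern in test/mpsgraphs/linalg.jl below. It assumes a Metal-capable device and that MPSGraphs is reachable as a submodule of Metal:

    using Metal, LinearAlgebra
    using Metal: MPSGraphs  # assumption: exposed as a submodule, like Metal.MPS

    a = MtlArray(rand(Float32, 8, 8))
    b = MtlArray(rand(Float32, 8, 8))
    c = MtlArray{Float32}(undef, 8, 8)

    # c = a * b (defaults: alpha = true, beta = false, no transposes)
    MPSGraphs.graph_matmul!(c, a, b)

    Array(c) ≈ Array(a) * Array(b)  # expected to hold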

lib/mpsgraphs/operations.jl

Lines changed: 33 additions & 0 deletions
@@ -1,11 +1,44 @@
 
+function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
+    obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
+                                     dataType:dataType::MPSDataType]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
 function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name="matmul")
     obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
                                      secondaryTensor:secondary::id{MPSGraphTensor}
                                      name:name::id{NSString}]::id{MPSGraphTensor}
     MPSGraphTensor(obj)
 end
 
+function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
+    obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+                                     secondaryTensor:secondary::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
+    obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
+                                     secondaryTensor:secondary::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
+    obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
+                                     dimension:dimension::NSUInteger
+                                     withDimension:withDimension::NSUInteger
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
+function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
+    obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
+                                     name:name::id{NSString}]::id{MPSGraphTensor}
+    MPSGraphTensor(obj)
+end
+
 run(graph::MPSGraph, feeds::Dict, targetTensors::Vector) = run(graph, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
 function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
     obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
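
As a rough sketch of how these wrappers compose into a runnable graph (following the pattern in matmul.jl above; the MPSMatrix feed and the result unwrapping are assumptions based on that file):

    # Build a graph that computes 2.0 .* A and run it on the GPU.
    graph = MPSGraph()
    A = MPSMatrix(MtlArray(rand(Float32, 4, 4)))

    place = placeholderTensor(graph, size(A), Float32)
    two   = constantWithScalar(graph, 2.0, Float32)
    out   = multiplicationWithPrimaryTensor(graph, two, place)

    # Feed the placeholder, evaluate the target tensor, and unwrap the
    # result the same way _matmul! does.
    res = run(graph, Dict(place => MPSGraphTensorData(A)), [out])
    resultdata = only(Dict{MPSGraphTensor, MPSGraphTensorData}(res)).second
    ndarr = MPSNDArray(resultdata)  # matmul.jl copies this back via exportToMtlArray!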

lib/mpsgraphs/tensor.jl

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,8 @@ function MPSGraphTensorData(vector::MPSVector)
     @objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData}
     return tensor
 end
-MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
-# MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))
+# MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
+MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))
 
 # rank must be between 1 and 16 inclusive
 function MPSGraphTensorData(vector::MPSVector, rank)

src/linalg.jl

Lines changed: 26 additions & 14 deletions
@@ -2,6 +2,24 @@ using LinearAlgebra
 using LinearAlgebra: MulAddMul, wrap
 using .MPS
 using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
+using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
+                  graph_matmul!, graph_matvecmul!
+
+@inline function supports_mps_matmul(A, B, C, valid_types)
+    MPS.is_supported(device(A)) &&
+        eltype(A) == eltype(B) &&
+        (eltype(A), eltype(C)) in valid_types
+end
+
+@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
+    MPS.is_supported(device(A)) &&
+        eltype(A) == eltype(B) &&
+        (eltype(A), eltype(C)) in valid_types &&
+        # TODO: remove this limitation
+        A.offset == 0 &&
+        B.offset == 0 &&
+        C.offset == 0
+end
 
 LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
     LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
@@ -28,13 +46,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
     transA = tA == 'T' || tA == 'C'
     transB = tB == 'T' || tB == 'C'
 
-    typA = eltype(A)
-    typB = eltype(B)
-    typC = eltype(C)
-
-    # If possible, dispatch to performance shaders
-    if MPS.is_supported(device()) &&
-            typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
+    # If possible, dispatch to MPSGraphs, then performance shaders
+    if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
+        graph_matmul!(C, A, B, alpha, beta, transA, transB)
+    elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
         matmul!(C, A, B, alpha, beta, transA, transB)
     else
         GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
@@ -66,13 +81,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
 
     transA = tA == 'T' || tA == 'C'
 
-    typA = eltype(A)
-    typB = eltype(B)
-    typC = eltype(C)
-
-    # If possible, dispatch to performance shaders
-    if MPS.is_supported(device()) &&
-            typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
+    # If possible, dispatch to MPSGraphs, then performance shaders
+    if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
+        graph_matvecmul!(C, A, B, alpha, beta, transA)
+    elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES)
         matvecmul!(C, A, B, alpha, beta, transA)
     else
         GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
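
With these checks in place, ordinary LinearAlgebra calls on MtlArrays should take the MPSGraph path whenever the eltype pair appears in the tables above. A sketch, assuming a supported device:

    using Metal, LinearAlgebra

    A = MtlArray(rand(Float32, 16, 16))
    B = MtlArray(rand(Float32, 16, 16))
    C = MtlArray{Float32}(undef, 16, 16)

    mul!(C, A, B)   # (Float32, Float32) is in MPSGRAPH_VALID_MATMUL_TYPES => graph_matmul!
    mul!(C, A', B)  # adjoint arrives as tA == 'C', so transA = true on the same path

    x = MtlArray(rand(Float32, 16))
    y = MtlArray{Float32}(undef, 16)
    mul!(y, A, x)   # matrix-vector case => graph_matvecmul!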

test/mpsgraphs/linalg.jl

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+using LinearAlgebra
+
+
+if MPS.is_supported(device())
+
+@testset "mixed-precision matrix matrix multiplication" begin
+    N = 10
+    rows_a = N
+    cols_a = N
+
+    rows_b = N
+    cols_b = N
+
+    rows_c = rows_a
+    cols_c = cols_b
+
+    alpha = Float64(1)
+    beta = Float64(1)
+
+    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+        arr_a = rand(input_jl_type, (rows_a, cols_a))
+        arr_b = rand(input_jl_type, (rows_b, cols_b))
+        arr_c = zeros(accum_jl_type, (rows_c, cols_c))
+
+        buf_a = MtlArray{input_jl_type}(arr_a)
+        buf_b = MtlArray{input_jl_type}(arr_b)
+        buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
+
+        truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
+
+        MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+        @test all(Array(buf_c) .≈ truth_c)
+    end
+end
+
+# XXX: Batched matmul not yet working
+# @testset "batched matrix matrix multiplication" begin
+#     N = 10
+#     batch_size = 3
+
+#     rows_a = N
+#     cols_a = N
+
+#     rows_b = N
+#     cols_b = N
+
+#     rows_c = rows_a
+#     cols_c = cols_b
+
+#     alpha = Float64(1)
+#     beta = Float64(1)
+
+#     @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+#         arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
+#         arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
+#         arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
+
+#         buf_a = MtlArray{input_jl_type}(arr_a)
+#         buf_b = MtlArray{input_jl_type}(arr_b)
+#         buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+
+#         truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+#         for i in 1:batch_size
+#             @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
+#         end
+
+#         MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+#         @test all(Array(buf_c) .≈ truth_c)
+#     end
+# end
+
+@testset "mixed-precision matrix vector multiplication" begin
+    N = 10
+    rows = N
+    cols = N
+
+    alpha = Float64(1)
+    beta = Float64(0)
+
+    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
+        arr_a = rand(input_jl_type, (rows, cols))
+        arr_b = rand(input_jl_type, (rows))
+        arr_c = zeros(accum_jl_type, (rows))
+
+        buf_a = MtlArray{input_jl_type}(arr_a)
+        buf_b = MtlArray{input_jl_type}(arr_b)
+        buf_c = MtlArray{accum_jl_type}(undef, (rows))
+
+        truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
+
+        MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
+
+        @test all(Array(buf_c) .≈ truth_c)
+        # @test Array(buf_c) ≈ truth_c
+    end
+end
+
+end
