Skip to content

Commit a4e8216

Browse files
committed
Remove copying
1 parent 6f8cc61 commit a4e8216

File tree

3 files changed

+21
-38
lines changed

3 files changed

+21
-38
lines changed

lib/mps/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using BFloat16s
2121
const MtlFloat = Union{Float32, Float16}
2222

2323
# Shape descriptor type used for MPS shape arguments. It is aliased to a plain
# `NSArray`; the trailing `#{NSNumber}` records the intended element type only as a
# comment, so it is not part of the alias itself.
const MPSShape = NSArray#{NSNumber}
24-
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
24+
# Convert a `Vector` or a tuple of integers describing dimensions into an `MPSShape`
# by wrapping every entry in an `NSNumber` and collecting them into an `NSArray`.
# NOTE(review): `T` is shared by both branches of the union: it binds to the element
# type for a `Vector` and to the length for an `NTuple`. Separate type variables
# would make the intent clearer; confirm before changing the signature.
function Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T
    dims = collect(tuple)
    return NSArray(NSNumber.(dims))
end
2525

2626
# Valid combination of input (A and B matrices) and output (C) types
2727
const MPS_VALID_MATMUL_TYPES =

lib/mpsgraphs/matmul.jl

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
1+
function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab}
22
graph = MPSGraph()
33

44
placeA = placeholderTensor(graph, size(a), Tab)
55
placeB = placeholderTensor(graph, size(b), Tab)
6+
outputTensorData = MPSGraphTensorData(c)
7+
8+
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
9+
placeA => MPSGraphTensorData(a),
10+
placeB => MPSGraphTensorData(b)
11+
)
612

713
castA, castB = if Tc != Tab
814
castTensor(graph, placeA, Tc, "castA"),
@@ -32,51 +38,32 @@ function _matmul!(c::MPSMatrix, ::Type{Tc}, a::MPSMatrix, b::MPSMatrix, ::Type{T
3238
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
3339
end
3440

35-
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
36-
placeA => MPSGraphTensorData(a),
37-
placeB => MPSGraphTensorData(b)
38-
)
39-
4041
afterbeta = if beta == 0
4142
afteralpha
4243
else
4344
placeC = placeholderTensor(graph, size(c), Tc)
44-
feeds[placeC] = MPSGraphTensorData(c)
45+
feeds[placeC] = outputTensorData
4546
betatensor = constantWithScalar(graph, beta, Tc)
4647
betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
4748
additionWithPrimaryTensor(graph, afteralpha, betaC)
4849
end
4950

50-
# Encode and commit matmul kernel
51-
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
52-
resultdict = encode!(cmdbuf, graph, NSDictionary(feeds), NSArray([afterbeta]))
53-
commitAndContinue!(cmdbuf)
51+
resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
52+
afterbeta => outputTensorData
53+
)
5454

55-
resultdata = MPSGraphTensorData(id{MPSGraphTensorData}(resultdict[afterbeta]))
55+
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
56+
encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict))
57+
commit!(cmdbuf)
58+
wait_completed(cmdbuf)
5659

57-
return cmdbuf, MPSNDArray(resultdata)
60+
return c
5861
end
5962

6063
function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
61-
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose_a, transpose_b)
62-
63-
commit!(cmdbuf) do cmdBuf
64-
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)
65-
end
66-
67-
wait_completed(cmdbuf)
68-
69-
return c
64+
_matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
7065
end
7166

7267
function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
73-
cmdbuf, resultndarr = _matmul!(MPSMatrix(c), Tc, MPSMatrix(a), MPSMatrix(b), Tab, alpha, beta, transpose, false)
74-
75-
commit!(cmdbuf) do cmdBuf
76-
exportDataWithCommandBuffer(resultndarr, cmdBuf, c.data[], Tc, c.offset)
77-
end
78-
79-
wait_completed(cmdbuf)
80-
81-
return c
68+
_matmul!(c, a, b, alpha, beta, transpose, false)
8269
end

lib/mpsgraphs/tensor.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function Base.size(td::MPSGraphTensor)
1414
end
1515

1616
function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...)
17-
mpsshape = convert(MPSShape, shape)
17+
mpsshape = convert(MPSShape, reverse(shape))
1818
return placeholderTensor(graph, mpsshape, args...)
1919
end
2020
function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor")
@@ -53,9 +53,7 @@ function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowByt
5353
rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData}
5454
return tensor
5555
end
56-
# MPSGraphTensorData(matrix::MtlMatrix{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
57-
MPSGraphTensorData(matrix::MtlMatrix) = MPSGraphTensorData(MPSMatrix(matrix))
58-
MPSGraphTensorData(arr::MtlArray{<:Any, 3}) = MPSGraphTensorData(MPSMatrix(arr))
56+
# Wrap an `MtlArray` as `MPSGraphTensorData` by handing its underlying buffer, its
# reversed dimensions (as an `MPSShape`) and its element type to the buffer-based
# constructor.
# NOTE(review): `matrix.offset` is not consulted here; confirm that arrays with a
# non-zero offset into their buffer are handled by the callee or excluded by callers.
function MPSGraphTensorData(matrix::MtlArray{T}) where T
    buffer = matrix.data[]
    shape = convert(MPSShape, reverse(size(matrix)))
    return MPSGraphTensorData(buffer, shape, T)
end
5957

6058
function MPSGraphTensorData(matrix::MPSMatrix)
6159
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
@@ -82,8 +80,6 @@ function MPSGraphTensorData(vector::MPSVector)
8280
@objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData}
8381
return tensor
8482
end
85-
# MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
86-
MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))
8783

8884
# rank must be between 1 and 16 inclusive
8985
function MPSGraphTensorData(vector::MPSVector, rank)

0 commit comments

Comments
 (0)