Skip to content

Commit 3f607e1

Browse files
Merge pull request #100 from chengchingwen/master
implementation for batch-wise matrix multiplication
2 parents c82a76d + 83e359c commit 3f607e1

File tree

6 files changed

+201
-0
lines changed

6 files changed

+201
-0
lines changed

src/NNlib.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ end
1414

1515
include("activation.jl")
1616
include("softmax.jl")
17+
include("batched/batchedmul.jl")
1718
include("gemm.jl")
1819
include("conv.jl")
1920
include("pooling.jl")

src/batched/batchedadjtrans.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using LinearAlgebra
2+
import Base: -
3+
4+
"""
5+
BatchedTranspose{T, N, S} <: AbstractBatchedMatrix{T, N}
6+
Batched transpose. Transpose a batch of matrix.
7+
"""
8+
struct BatchedTranspose{T, S} <: AbstractArray{T, 3}
9+
parent::S
10+
BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X)
11+
end
12+
13+
"""
14+
batched_transpose(A)
15+
Lazy batched transpose.
16+
"""
17+
batched_transpose(A::AbstractArray{T}) where T = BatchedTranspose(A)
18+
batched_transpose(A::BatchedTranspose) = A.parent
19+
20+
"""
21+
BatchedAdjoint{T, N, S} <: AbstractBatchedMatrix{T, N}
22+
Batched ajoint. Transpose a batch of matrix.
23+
"""
24+
struct BatchedAdjoint{T, S} <: AbstractArray{T, 3}
25+
parent::S
26+
BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X)
27+
end
28+
29+
"""
30+
batched_adjoint(A)
31+
Lazy batched adjoint.
32+
"""
33+
batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A)
34+
batched_adjoint(A::BatchedAdjoint) = A.parent
35+
36+
BatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
37+
BatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)
38+
39+
40+
const BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}}
41+
42+
LinearAlgebra.wrapperop(A::BatchedAdjoint) = batched_adjoint
43+
LinearAlgebra.wrapperop(B::BatchedTranspose) = batched_transpose
44+
45+
# AbstractArray Interface
46+
Base.length(A::BatchedAdjOrTrans) = length(A.parent)
47+
Base.size(m::BatchedAdjOrTrans) = (size(m.parent, 2), size(m.parent, 1), size(m.parent, 3))
48+
Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.parent, 3))
49+
50+
Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
51+
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
52+
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
53+
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjOrTrans, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
54+
55+
Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
56+
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)
57+
Base.similar(A::BatchedAdjOrTrans, T::Type) = similar(A.parent, T, size(A))
58+
Base.similar(A::BatchedAdjOrTrans) = similar(A.parent, size(A))
59+
60+
Base.parent(A::BatchedAdjOrTrans) = A.parent
61+
62+
(-)(A::BatchedAdjoint) = BatchedAdjoint( -A.parent)
63+
(-)(A::BatchedTranspose) = BatchedTranspose(-A.parent)
64+
65+
Base.copy(A::BatchedTranspose) = BatchedTranspose(copy(A.parent))
66+
Base.copy(A::BatchedAdjoint) = BatchedAdjoint(copy(A.parent))
67+

src/batched/batchedmul.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# batch-wise matrix multiplication
2+
# wrapper for batched_gemm!
3+
export batched_mul, batched_transpose, batched_adjoint
4+
5+
6+
include("./batchedadjtrans.jl")
7+
8+
function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
9+
size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch size mismatch"))
10+
batched_mul!(similar(A, (size(A, 1), size(B, 2), size(A, 3))), A, B)
11+
end
12+
13+
"""
14+
batched_mul!(C, A, B) -> C
15+
batched `mul!`.
16+
"""
17+
function batched_mul! end
18+
19+
_unbatch(A) = A
20+
_unbatch(A::BatchedAdjOrTrans) = A.parent
21+
22+
# bmm
23+
const _BATCHED_MATRIX_LIST = [
24+
(:(AbstractArray{T, 3}), 'N'),
25+
(:(BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T'),
26+
(:(BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C')
27+
]
28+
29+
for (TA, transA) in _BATCHED_MATRIX_LIST, (TB, transB) in _BATCHED_MATRIX_LIST
30+
@eval begin
31+
function batched_mul!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where T
32+
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
33+
C
34+
end
35+
36+
37+
end
38+
end

src/gemm.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,50 @@ for (gemm, elt) in gemm_datatype_mappings
5656
end
5757
end
5858
end
59+
60+
for (gemm, elt) in gemm_datatype_mappings
61+
@eval begin
62+
@inline function batched_gemm!(transA::AbstractChar,
63+
transB::AbstractChar,
64+
alpha::($elt),
65+
A::AbstractArray{$elt, 3},
66+
B::AbstractArray{$elt, 3},
67+
beta::($elt),
68+
C::AbstractArray{$elt, 3})
69+
@assert !Base.has_offset_axes(A, B, C)
70+
@assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch"
71+
m = size(A, transA == 'N' ? 1 : 2)
72+
ka = size(A, transA == 'N' ? 2 : 1)
73+
kb = size(B, transB == 'N' ? 1 : 2)
74+
n = size(B, transB == 'N' ? 2 : 1)
75+
if ka != kb || m != size(C,1) || n != size(C,2)
76+
throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
77+
end
78+
LinearAlgebra.BLAS.chkstride1(A)
79+
LinearAlgebra.BLAS.chkstride1(B)
80+
LinearAlgebra.BLAS.chkstride1(C)
81+
82+
ptrA = Base.unsafe_convert(Ptr{$elt}, A)
83+
ptrB = Base.unsafe_convert(Ptr{$elt}, B)
84+
ptrC = Base.unsafe_convert(Ptr{$elt}, C)
85+
86+
for k in 1:size(A, 3)
87+
ccall((@blasfunc($(gemm)), libblas), Nothing,
88+
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
89+
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
90+
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
91+
Ref{BlasInt}),
92+
transA, transB, m, n,
93+
ka, alpha, ptrA, max(1,Base.stride(A,2)),
94+
ptrB, max(1,Base.stride(B,2)), beta, ptrC,
95+
max(1,Base.stride(C,2)))
96+
97+
ptrA += size(A, 1) * size(A, 2) * sizeof($elt)
98+
ptrB += size(B, 1) * size(B, 2) * sizeof($elt)
99+
ptrC += size(C, 1) * size(C, 2) * sizeof($elt)
100+
end
101+
102+
C
103+
end
104+
end
105+
end

test/batchedmul.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
function bmm_test(a,b; transA = false, transB = false)
2+
bs = size(a,3)
3+
transA && (a = permutedims(a, [2,1,3]))
4+
transB && (b = permutedims(b, [2,1,3]))
5+
c = []
6+
for i = 1:bs
7+
push!(c, a[:,:,i]*b[:,:,i])
8+
end
9+
10+
cat(c...; dims = 3)
11+
end
12+
13+
function bmm_adjtest(a,b; adjA = false, adjB = false)
14+
bs = size(a,3)
15+
c = []
16+
for i = 1:bs
17+
ai = adjA ? adjoint(a[:,:,i]) : a[:,:,i]
18+
bi = adjB ? adjoint(b[:,:,i]) : b[:,:,i]
19+
push!(c, ai*bi)
20+
end
21+
22+
cat(c...; dims = 3)
23+
end
24+
25+
@testset "Batched Matrix Multiplication" begin
26+
A = randn(7,5,3)
27+
B = randn(5,7,3)
28+
C = randn(7,6,3)
29+
30+
@test batched_mul(A, B) == bmm_test(A, B)
31+
@test batched_mul(batched_transpose(A), batched_transpose(B)) == bmm_test(A, B; transA = true, transB = true)
32+
@test batched_mul(batched_transpose(A), C) == bmm_test(A, C; transA = true)
33+
@test batched_mul(A, batched_transpose(A)) == bmm_test(A, A; transB = true)
34+
35+
36+
cA = randn(Complex{Float64}, 7,5,3)
37+
cB = randn(Complex{Float64}, 5,7,3)
38+
cC = randn(Complex{Float64}, 7,6,3)
39+
40+
@test batched_mul(cA, cB) == bmm_adjtest(cA, cB)
41+
@test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) == bmm_adjtest(cA, cB; adjA = true, adjB = true)
42+
@test batched_mul(batched_adjoint(cA), cC) == bmm_adjtest(cA, cC; adjA = true)
43+
@test batched_mul(cA, batched_adjoint(cA)) == bmm_adjtest(cA, cA; adjB = true)
44+
45+
@test batched_transpose(batched_transpose(A)) == A
46+
@test batched_adjoint(batched_adjoint(cA)) == cA
47+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ using NNlib, Test
22

33
include("activation.jl")
44
include("conv.jl")
5+
include("batchedmul.jl")
56
include("pooling.jl")
67
include("inference.jl")

0 commit comments

Comments
 (0)