Skip to content

Commit 4a8251b

Browse files
Roger-luochengchingwen
authored and committed
implementation for batch-wise matrix multiplication
1 parent 2bd7e8a commit 4a8251b

File tree

5 files changed

+114
-1
lines changed

5 files changed

+114
-1
lines changed

src/NNlib.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ module NNlib
33
using Requires, Libdl
44

55
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
6-
softmax, logsoftmax, maxpool, meanpool
6+
softmax, logsoftmax, maxpool, meanpool, batchedmul
77

88
include("numeric.jl")
99
include("activation.jl")
1010
include("softmax.jl")
1111
include("logsoftmax.jl")
1212
include("linalg.jl")
13+
include("batchedmul.jl")
1314
include("conv.jl")
1415
include("cubroadcast.jl")
1516

src/batchedmul.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# batch-wise matrix multiplication
2+
# wrapper for batched_gemm!
3+
4+
"""
    batchedmul(a, b; transA=false, transB=false)

Batch-wise matrix multiplication of two 3d arrays along their third dimension.

For every batch index `i`, slice `i` of the result is `op(a[:, :, i]) * op(b[:, :, i])`,
where `op` transposes the slice when the matching `transA` / `transB` flag is `true`.
The result is allocated with `similar(a, ...)` and filled by `batched_mul!`.

Raises an error ("batch size mismatch") when `size(a, 3) != size(b, 3)`.
"""
function batchedmul(a::AbstractArray{T, 3}, b::AbstractArray{T, 3};
                    transA::Bool = false, transB::Bool = false) where T
    nbatch = size(a, 3)
    if nbatch != size(b, 3)
        error("batch size mismatch")
    end

    # Output slice shape follows op(a) (rows) and op(b) (columns).
    nrows = transA ? size(a, 2) : size(a, 1)
    ncols = transB ? size(b, 1) : size(b, 2)

    out = similar(a, nrows, ncols, nbatch)
    batched_mul!(out, a, b; transA=transA, transB=transB)
    return out
end
11+
12+
"""
    batched_mul!(C, A, B; transA=false, transB=false)

In-place batch-wise matrix multiplication: overwrite `C` with the batched product of
`A` and `B` by calling `batched_gemm!` with `alpha = one(T)` and `beta = zero(T)`.

`transA` / `transB` select the operation character passed on for `A` / `B`
(`'T'` when `true`, `'N'` otherwise). Returns `C`.
"""
function batched_mul!(C::AbstractArray{T, 3},
                      A::AbstractArray{T, 3},
                      B::AbstractArray{T, 3};
                      transA::Bool = false, transB::Bool = false) where T
    opA = transA ? 'T' : 'N'
    opB = transB ? 'T' : 'N'
    # beta = 0 means the previous contents of C do not contribute to the result.
    batched_gemm!(opA, opB, one(T), A, B, zero(T), C)
    return C
end
19+
20+
#gradient function for batchedmul
21+
"""
    ∇batchedmul(Δ, a, b; transA=false, transB=false)

Gradient function for `batchedmul`.

Given `Δ`, the gradient with respect to the output of
`batchedmul(a, b; transA=transA, transB=transB)`, return a 2-tuple of arrays:
the gradient with respect to `a` and the gradient with respect to `b`.
Each entry is itself computed with `batchedmul`, so the same batch-size check applies.
"""
function ∇batchedmul(Δ::AbstractArray{T, 3}, a::AbstractArray{T, 3}, b::AbstractArray{T, 3};
                     transA::Bool = false, transB::Bool = false) where T
    if transA
        if transB
            # C = aᵀ bᵀ  =>  ∂a = bᵀ Δᵀ,  ∂b = Δᵀ aᵀ
            return (batchedmul(b, Δ; transA=true, transB=true),
                    batchedmul(Δ, a; transA=true, transB=true))
        else
            # C = aᵀ b   =>  ∂a = b Δᵀ,   ∂b = a Δ
            return (batchedmul(b, Δ; transB=true), batchedmul(a, Δ))
        end
    else
        if transB
            # C = a bᵀ   =>  ∂a = Δ b,    ∂b = Δᵀ a
            return (batchedmul(Δ, b), batchedmul(Δ, a; transA=true))
        else
            # C = a b    =>  ∂a = Δ bᵀ,   ∂b = aᵀ Δ
            return (batchedmul(Δ, b; transB=true), batchedmul(a, Δ; transA=true))
        end
    end
end

src/linalg.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,58 @@ for (gemm, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32))
2626
transA, transB, M, N, K,
2727
alpha, A, lda, B, ldb, beta, C, ldc)
2828
end
29+
30+
## borrowed from BatchedRoutines.jl
31+
# batched gemm for 3d-array
32+
# C[:,:,i] := alpha*op(A[:,:,i])*op(B[:,:,i]) + beta*C[:,:,i], where:
33+
# i is the specific batch number,
34+
# op(X) is one of op(X) = X, or op(X) = X^T, or op(X) = X^H,
35+
# alpha and beta are scalars,
36+
# A, B and C are 3d Array:
37+
# op(A) is an m-by-k-by-b 3d Array,
38+
# op(B) is a k-by-n-by-b 3d Array,
39+
# C is an m-by-n-by-b 3d Array.
40+
# Batched GEMM over the third dimension of 3d arrays. For every batch index i:
#   C[:,:,i] = alpha*op(A[:,:,i])*op(B[:,:,i]) + beta*C[:,:,i]
# `transA` / `transB` are the operation characters handed to BLAS gemm; 'N' uses the
# slice as stored, any other character makes the size checks treat it as transposed.
# Throws `DimensionMismatch` when the inner or output dimensions disagree; the
# offset-axes and batch-size checks are `@assert`s. `C` is overwritten and returned.
function batched_gemm!(transA::AbstractChar,
                       transB::AbstractChar,
                       alpha::($elty),
                       A::AbstractArray{$elty, 3},
                       B::AbstractArray{$elty, 3},
                       beta::($elty),
                       C::AbstractArray{$elty, 3})
    @assert !LinearAlgebra.BLAS.has_offset_axes(A, B, C)
    @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch"
    m = size(A, transA == 'N' ? 1 : 2)
    ka = size(A, transA == 'N' ? 2 : 1)
    kb = size(B, transB == 'N' ? 1 : 2)
    n = size(B, transB == 'N' ? 2 : 1)
    if ka != kb || m != size(C,1) || n != size(C,2)
        throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
    end
    LinearAlgebra.BLAS.chkstride1(A)
    LinearAlgebra.BLAS.chkstride1(B)
    LinearAlgebra.BLAS.chkstride1(C)

    # Leading dimensions are loop-invariant, so compute them once.
    lda = max(1, Base.stride(A, 2))
    ldb = max(1, Base.stride(B, 2))
    ldc = max(1, Base.stride(C, 2))

    # Byte distance between consecutive batch slices. Use the real third-dimension
    # stride rather than size(X,1)*size(X,2): the two only agree for dense storage,
    # and strided inputs (e.g. views) are accepted here as long as stride 1 is 1.
    stepA = Base.stride(A, 3) * sizeof($elty)
    stepB = Base.stride(B, 3) * sizeof($elty)
    stepC = Base.stride(C, 3) * sizeof($elty)

    # Raw pointers do not keep the arrays alive; root them for the whole loop.
    GC.@preserve A B C begin
        ptrA = Base.unsafe_convert(Ptr{$elty}, A)
        ptrB = Base.unsafe_convert(Ptr{$elty}, B)
        ptrC = Base.unsafe_convert(Ptr{$elty}, C)

        for _ in 1:size(A, 3)
            ccall((LinearAlgebra.BLAS.@blasfunc($gemm), LinearAlgebra.BLAS.libblas), Cvoid,
                  (Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
                   Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{LinearAlgebra.BLAS.BlasInt},
                   Ptr{$elty}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$elty}, Ptr{$elty},
                   Ref{LinearAlgebra.BLAS.BlasInt}),
                  transA, transB, m, n,
                  ka, alpha, ptrA, lda,
                  ptrB, ldb, beta, ptrC,
                  ldc)

            ptrA += stepA
            ptrB += stepB
            ptrC += stepC
        end
    end

    return C
end
2982
end
3083
end

test/batchedmul.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Reference implementation used by the tests: multiply the matrices of `a` and `b`
# slice by slice along the third dimension and stack the products along dims = 3.
# `transA` / `transB` swap the first two dimensions of the corresponding input first.
function bmm_test(a, b; transA = false, transB = false)
    nbatch = size(a, 3)
    if transA
        a = permutedims(a, [2, 1, 3])
    end
    if transB
        b = permutedims(b, [2, 1, 3])
    end

    # One ordinary matrix product per batch slice.
    products = [a[:, :, i] * b[:, :, i] for i = 1:nbatch]

    return cat(products...; dims = 3)
end
12+
13+
@testset "Batched Matrix Multiplication" begin
    A = randn(7,5,3)
    B = randn(5,7,3)
    C = randn(7,6,3)

    # Compare with ≈ rather than ==: batchedmul and the reference go through different
    # multiplication calls (transposed operands vs. explicitly permuted copies), so the
    # floating-point results may legitimately differ in the last bits.
    @test batchedmul(A, B) ≈ bmm_test(A, B)
    @test batchedmul(A, B; transA = true, transB = true) ≈ bmm_test(A, B; transA = true, transB = true)
    @test batchedmul(A, C; transA = true) ≈ bmm_test(A, C; transA = true)
    @test batchedmul(A, A; transB = true) ≈ bmm_test(A, A; transB = true)
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using NNlib, Test
44

55
include("activation.jl")
66
include("conv.jl")
7+
include("batchedmul.jl")
78

89
# Two equal, very negative inputs must still yield a uniform distribution.
xs = [-100_000, -100_000.]
@test softmax(xs) ≈ [0.5, 0.5]

0 commit comments

Comments
 (0)