Skip to content

Commit 6237de3

Browse files
committed
port batched transpose type, remove transpose keyword
1 parent 4a8251b commit 6237de3

File tree

4 files changed

+116
-32
lines changed

4 files changed

+116
-32
lines changed

src/NNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module NNlib
33
using Requires, Libdl
44

55
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
6-
softmax, logsoftmax, maxpool, meanpool, batchedmul
6+
softmax, logsoftmax, maxpool, meanpool, batched_mul, batched_transpose
77

88
include("numeric.jl")
99
include("activation.jl")

src/batchedadjtrans.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using LinearAlgebra
2+
import Base: -
3+
"""
    BatchedTranspose{T, S} <: AbstractArray{T, 3}

Lazy wrapper around a 3-dimensional array `parent::S` that presents it with its
first two dimensions exchanged, i.e. as a batch of transposed matrices. The third
dimension is the batch index. The wrapped array is stored by reference; nothing is
copied.

`T` is the element type reported by the wrapper and `S` is the type of the wrapped
array.
"""
struct BatchedTranspose{T, S} <: AbstractArray{T, 3}
    parent::S

    # Inner constructor: both type parameters must be given explicitly.
    function BatchedTranspose{T, S}(data::S) where {T, S}
        return new{T, S}(data)
    end
end
12+
"""
    batched_transpose(A::AbstractArray{T, 3})

Return a lazy `BatchedTranspose` wrapper around `A`. No data is copied.

Only 3-dimensional arrays are accepted: the wrapper is declared as a
3-dimensional array, so wrapping an array of any other rank would produce an
object whose reported shape disagrees with its parent.
"""
batched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A)
18+
19+
"""
    BatchedAdjoint{T, S} <: AbstractArray{T, 3}

Lazy wrapper around a 3-dimensional array `parent::S` that presents it as a batch
of adjoint (conjugate-transposed) matrices: the first two dimensions are exchanged
and `adjoint` is applied to each element on access. The third dimension is the
batch index. The wrapped array is stored by reference; nothing is copied.

`T` is the element type reported by the wrapper and `S` is the type of the wrapped
array.
"""
struct BatchedAdjoint{T, S} <: AbstractArray{T, 3}
    parent::S

    # Inner constructor: both type parameters must be given explicitly.
    function BatchedAdjoint{T, S}(data::S) where {T, S}
        return new{T, S}(data)
    end
end
28+
"""
    batched_adjoint(A::AbstractArray{T, 3})

Return a lazy `BatchedAdjoint` wrapper around the 3-dimensional array `A`.
No data is copied.
"""
function batched_adjoint(A::AbstractArray{T, 3}) where T
    return BatchedAdjoint(A)
end
34+
35+
36+
# Convenience constructor: derive the wrapper's element type from what `adjoint`
# yields for an element of `A`, and record the concrete type of `A` itself.
function BatchedAdjoint(A)
    T = Base.promote_op(adjoint, eltype(A))
    return BatchedAdjoint{T, typeof(A)}(A)
end
37+
# Convenience constructor: derive the wrapper's element type from what `transpose`
# yields for an element of `A`, and record the concrete type of `A` itself.
function BatchedTranspose(A)
    T = Base.promote_op(transpose, eltype(A))
    return BatchedTranspose{T, typeof(A)}(A)
end
38+
39+
40+
# Union of the two lazy batched wrappers, so that methods whose behaviour is the
# same for both (shape queries, `similar`, `parent`, ...) can be written once.
const BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}}
41+
42+
# Report, for each wrapper type, the function that builds that kind of wrapper.
# NOTE(review): `wrapperop` is an internal LinearAlgebra function; confirm it is
# still present in the supported Julia versions.
function LinearAlgebra.wrapperop(::BatchedAdjoint)
    return batched_adjoint
end

function LinearAlgebra.wrapperop(::BatchedTranspose)
    return batched_transpose
end
44+
45+
# AbstractArray Interface
46+
# The wrapper has exactly as many elements as the array it wraps.
function Base.length(A::BatchedAdjOrTrans)
    return length(A.parent)
end

# Dimensions 1 and 2 of the parent are exchanged; dimension 3 is reported unchanged.
function Base.size(m::BatchedAdjOrTrans)
    P = m.parent
    return (size(P, 2), size(P, 1), size(P, 3))
end

# Same exchange as `size`, applied to the parent's axes.
function Base.axes(m::BatchedAdjOrTrans)
    P = m.parent
    return (axes(P, 2), axes(P, 1), axes(P, 3))
end
49+
50+
# Elements are addressed with three Cartesian indices.
function Base.IndexStyle(::Type{<:BatchedAdjOrTrans})
    return IndexCartesian()
end

# Element (i, j, k) of the transposed view is element (j, i, k) of the parent.
Base.@propagate_inbounds function Base.getindex(
    m::BatchedTranspose, i::Int, j::Int, k::Int
)
    return m.parent[j, i, k]
end

# Same index exchange as above, with `adjoint` applied to the element that is read.
Base.@propagate_inbounds function Base.getindex(
    m::BatchedAdjoint, i::Int, j::Int, k::Int
)
    return adjoint(m.parent[j, i, k])
end
53+
# Writing through a wrapper stores into the parent at the swapped position
# (j, i, k). The transpose wrapper stores `v` as given.
Base.@propagate_inbounds function Base.setindex!(
    m::BatchedTranspose, v, i::Int, j::Int, k::Int
)
    return setindex!(m.parent, v, j, i, k)
end

# The adjoint wrapper applies `adjoint` to every element it reads, so the value
# written to the parent must be `adjoint(v)`; otherwise `m[i, j, k] = v` followed by
# `m[i, j, k]` would return `adjoint(v)` instead of `v` for complex elements.
Base.@propagate_inbounds function Base.setindex!(
    m::BatchedAdjoint, v, i::Int, j::Int, k::Int
)
    return setindex!(m.parent, adjoint(v), j, i, k)
end
54+
55+
# Allocation requests are forwarded to the wrapped array, so the result has whatever
# array type `similar` produces for the parent. When no dimensions are given, the
# wrapper's own size (first two dimensions exchanged) is used.
function Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims)
    return similar(A.parent, T, dims)
end

function Base.similar(A::BatchedAdjOrTrans, dims::Dims)
    return similar(A.parent, dims)
end

function Base.similar(A::BatchedAdjOrTrans, T::Type)
    return similar(A.parent, T, size(A))
end

function Base.similar(A::BatchedAdjOrTrans)
    return similar(A.parent, size(A))
end
59+
60+
# Return the array wrapped by a batched transpose/adjoint, without copying.
function Base.parent(A::BatchedAdjOrTrans)
    return A.parent
end
61+
62+
# Unary minus: negate the wrapped array (this materialises `-A.parent`) and wrap the
# result in the same kind of lazy wrapper.
function Base.:-(A::BatchedAdjoint)
    negated = -A.parent
    return BatchedAdjoint(negated)
end

function Base.:-(A::BatchedTranspose)
    negated = -A.parent
    return BatchedTranspose(negated)
end
64+
65+
# Copy the wrapped array and re-wrap it, so the result is an independent array that
# is still a lazy wrapper of the same kind.
function Base.copy(A::BatchedTranspose)
    duplicate = copy(A.parent)
    return BatchedTranspose(duplicate)
end

function Base.copy(A::BatchedAdjoint)
    duplicate = copy(A.parent)
    return BatchedAdjoint(duplicate)
end
67+

src/batchedmul.jl

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,53 @@
11
# batch-wise matrix multiplication
22
# wrapper for batched_gemm!
33

4-
function batchedmul(a::AbstractArray{T, 3}, b::AbstractArray{T, 3};
5-
transA::Bool = false, transB::Bool = false) where T
6-
(bs = size(a, 3)) == size(b, 3) || error("batch size mismatch")
7-
res = similar(a, size(a, transA ? 2 : 1), size(b, transB ? 1 : 2), bs)
8-
batched_mul!(res, a, b; transA=transA, transB=transB)
9-
return res
10-
end
4+
include("./batchedadjtrans.jl")
115

12-
function batched_mul!(C::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::AbstractArray{T, 3};
13-
transA::Bool = false, transB::Bool = false) where T
14-
At = transA ? 'T' : 'N'
15-
Bt = transB ? 'T' : 'N'
16-
batched_gemm!(At, Bt, one(T), A, B, zero(T), C)
17-
C
6+
"""
    batched_mul(A, B) -> C

Allocate an output array with `similar(A, ...)` of size
`(size(A, 1), size(B, 2), size(A, 3))` and fill it by calling
`batched_mul!(C, A, B)`, whose result is returned.

# Throws
- `DimensionMismatch` if `A` and `B` differ in the size of their third dimension.
"""
function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
    nbatch = size(A, 3)
    if nbatch != size(B, 3)
        throw(DimensionMismatch("batch size mismatch"))
    end
    out = similar(A, (size(A, 1), size(B, 2), nbatch))
    return batched_mul!(out, A, B)
end
1910

20-
#gradient function for batchedmul
21-
function ∇batchedmul::AbstractArray{T, 3}, a::AbstractArray{T, 3}, b::AbstractArray{T, 3};
22-
transA::Bool = false, transB::Bool = false) where T
23-
if transA
24-
if transB
25-
(batchedmul(b, Δ; transA=true, transB=true), batchedmul(Δ, a; transA=true, transB=true))
26-
else
27-
(batchedmul(b, Δ; transB=true), batchedmul(a, Δ))
28-
end
29-
else
30-
if transB
31-
(batchedmul(Δ, b), batchedmul(Δ, a; transA=true))
32-
else
33-
(batchedmul(Δ, b; transB=true), batchedmul(a, Δ; transA=true))
"""
    batched_mul!(C, A, B) -> C

In-place batched counterpart of `mul!`: write the batched product of `A` and `B`
into `C` and return `C`.

This statement only declares the generic function. The methods defined in this
file forward to `batched_gemm!` with scale factors `one(T)` and `zero(T)`.
"""
function batched_mul! end
16+
17+
# Strip a lazy batched transpose/adjoint wrapper, returning the wrapped array.
# Anything that is not such a wrapper is returned unchanged.
function _unbatch(A)
    return A
end

function _unbatch(A::BatchedAdjOrTrans)
    return A.parent
end
19+
20+
# bmm
21+
# Pairs of (argument type expression, flag character) used to generate the
# `batched_mul!` methods: 'N' for a plain array, 'T' for a batched transpose and
# 'C' for a batched adjoint. The flag is what gets passed to `batched_gemm!`.
const _BATCHED_MATRIX_LIST = Tuple{Expr, Char}[
    (:(AbstractArray{T, 3}), 'N'),
    (:(BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T'),
    (:(BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C'),
]
26+
27+
# Generate one `batched_mul!` method for every pairing of the three argument kinds
# in `_BATCHED_MATRIX_LIST` (3 x 3 = 9 methods). Each method unwraps its arguments
# and forwards to `batched_gemm!` with the flag characters of that pairing and the
# scale factors `one(T)` and `zero(T)`, then returns `C`.
for (typeA, flagA) in _BATCHED_MATRIX_LIST
    for (typeB, flagB) in _BATCHED_MATRIX_LIST
        @eval function batched_mul!(C::AbstractArray{T, 3}, A::$typeA, B::$typeB) where T
            batched_gemm!($flagA, $flagB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
            return C
        end
    end
end
37+
38+
"""
    ∇batched_mul(Δ, A, B) -> (∇A, ∇B)

Gradient of `C = batched_mul(A, B)` with respect to `A` and `B`, given the
gradient `Δ` with respect to `C`. Returns the tuple `(Δ * Bᵀ, Aᵀ * Δ)`, where
the products and transposes are taken batch-wise; `∇A` has the shape of `A` and
`∇B` has the shape of `B`.
"""
function ∇batched_mul(
    Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::AbstractArray{T, 3}
) where T
    return (batched_mul(Δ, batched_transpose(B)), batched_mul(batched_transpose(A), Δ))
end
41+
42+
43+
# Gradient of `C = batched_mul(A, B)` when `A` is a lazy batched transpose.
# Returns `(Δ * Bᵀ, Aᵀ * Δ)`, the same quantities as for plain arrays. The transpose
# of `A` is exactly the array it wraps, so `A.parent` is used for `Aᵀ`; multiplying
# by `A` itself would compute `A * Δ`, whose inner dimensions only agree when the
# matrices are square.
function ∇batched_mul(
    Δ::AbstractArray{T, 3},
    A::BatchedTranspose{T, <:AbstractArray{T, 3}},
    B::AbstractArray{T, 3},
) where T
    return (batched_mul(Δ, batched_transpose(B)), batched_mul(A.parent, Δ))
end
46+
47+
# Gradient of `C = batched_mul(A, B)` when `B` is a lazy batched transpose.
# Returns `(Δ * Bᵀ, Aᵀ * Δ)`, the same quantities as for plain arrays. The transpose
# of `B` is exactly the array it wraps, so `B.parent` is used for `Bᵀ`; multiplying
# by `B` itself would compute `Δ * B`, whose inner dimensions only agree when the
# matrices are square.
function ∇batched_mul(
    Δ::AbstractArray{T, 3},
    A::AbstractArray{T, 3},
    B::BatchedTranspose{T, <:AbstractArray{T, 3}},
) where T
    return (batched_mul(Δ, B.parent), batched_mul(batched_transpose(A), Δ))
end
50+
51+
# Gradient of `C = batched_mul(A, B)` when both `A` and `B` are lazy batched
# transposes. Returns `(Δ * Bᵀ, Aᵀ * Δ)`, the same quantities as for plain arrays.
# Each transpose is exactly the wrapped array, so the parents are used directly and
# `Δ` is not transposed: `∇A` must have the shape of `A` and `∇B` the shape of `B`.
function ∇batched_mul(
    Δ::AbstractArray{T, 3},
    A::BatchedTranspose{T, <:AbstractArray{T, 3}},
    B::BatchedTranspose{T, <:AbstractArray{T, 3}},
) where T
    return (batched_mul(Δ, B.parent), batched_mul(A.parent, Δ))
end

test/batchedmul.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ end
1515
B = randn(5,7,3)
1616
C = randn(7,6,3)
1717

18-
@test batchedmul(A, B) == bmm_test(A, B)
19-
@test batchedmul(A, B; transA = true, transB = true) == bmm_test(A, B; transA = true, transB = true)
20-
@test batchedmul(A, C; transA = true) == bmm_test(A, C; transA = true)
21-
@test batchedmul(A, A; transB = true) == bmm_test(A, A; transB = true)
18+
@test batched_mul(A, B) == bmm_test(A, B)
19+
@test batched_mul(batched_transpose(A), batched_transpose(B)) == bmm_test(A, B; transA = true, transB = true)
20+
@test batched_mul(batched_transpose(A), C) == bmm_test(A, C; transA = true)
21+
@test batched_mul(A, batched_transpose(A)) == bmm_test(A, A; transB = true)
2222
end

0 commit comments

Comments
 (0)