Skip to content

Commit 8a4ec87

Browse files
Merge pull request #181 from mcabbott/batch
Add fallback batched_mul!
2 parents 60ac742 + 64db424 commit 8a4ec87

File tree

3 files changed

+86
-43
lines changed

3 files changed

+86
-43
lines changed

src/batched/batchedadjtrans.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,38 @@
11
using LinearAlgebra
22
import Base: -
33

4-
"""
4+
_batched_doc = """
5+
batched_transpose(A::AbstractArray{T,3})
6+
batched_adjoint(A)
7+
8+
Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.
9+
10+
These exist to control how `batched_mul` behaves,
11+
as it operates on such matrix slices of an array with `ndims(A)==3`.
12+
513
BatchedTranspose{T, N, S} <: AbstractBatchedMatrix{T, N}
6-
Batched transpose. Transpose a batch of matrix.
14+
BatchedAdjoint{T, N, S}
15+
16+
Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` and `batched_adjoint` respectively.
717
"""
18+
19+
@doc _batched_doc
820
struct BatchedTranspose{T, S} <: AbstractArray{T, 3}
921
parent::S
1022
BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X)
1123
end
1224

13-
"""
14-
batched_transpose(A)
15-
Lazy batched transpose.
16-
"""
25+
@doc _batched_doc
1726
batched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A)  # 3-D only: matches `batched_adjoint` and the `AbstractArray{T, 3}` supertype of `BatchedTranspose`
1827
batched_transpose(A::BatchedTranspose) = A.parent  # transposing an already-wrapped array unwraps it and returns the stored parent
1928

20-
"""
21-
BatchedAdjoint{T, N, S} <: AbstractBatchedMatrix{T, N}
22-
Batched ajoint. Transpose a batch of matrix.
23-
"""
29+
@doc _batched_doc
2430
struct BatchedAdjoint{T, S} <: AbstractArray{T, 3}
2531
parent::S
2632
BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X)
2733
end
2834

29-
"""
30-
batched_adjoint(A)
31-
Lazy batched adjoint.
32-
"""
35+
@doc _batched_doc
3336
batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A)  # wrap a 3-D array in the `BatchedAdjoint` wrapper type
3437
batched_adjoint(A::BatchedAdjoint) = A.parent  # adjoint of an already-wrapped array unwraps it and returns the stored parent
3538

src/batched/batchedmul.jl

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,63 @@
22
# wrapper for batched_gemm!
33
export batched_mul, batched_transpose, batched_adjoint
44

5-
65
include("./batchedadjtrans.jl")
76

8-
function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
9-
size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch size mismatch"))
10-
batched_mul!(similar(A, (size(A, 1), size(B, 2), size(A, 3))), A, B)
7+
"""
8+
batched_mul(A, B) -> C
9+
10+
Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
11+
"""
12+
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
13+
axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
14+
T = promote_type(T1, T2)
15+
C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3)))
16+
batched_mul!(C, A, B)
1117
end
1218

1319
"""
1420
batched_mul!(C, A, B) -> C
15-
batched `mul!`.
21+
22+
In-place batched matrix multiplication,
23+
equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k])` for all `k`.
1624
"""
1725
function batched_mul! end  # declares the generic function with no methods; the methods are generated by the `@eval` loops below
1826

1927
_unbatch(A) = A  # generic case: anything that is not a batched wrapper passes through unchanged
2028
_unbatch(A::BatchedAdjOrTrans) = A.parent  # strip the batched transpose/adjoint wrapper to reach the stored parent array
2129

22-
# bmm
23-
const _BATCHED_MATRIX_LIST = [
24-
(:(AbstractArray{T, 3}), 'N'),
25-
(:(BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T'),
26-
(:(BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C')
30+
# batched_gemm!
31+
32+
const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32}
33+
34+
_BATCHED_GEMM_LIST = [
35+
(:(StridedArray{T, 3}), 'N'),
36+
(:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'),
37+
(:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C')
2738
]
2839

29-
for (TA, transA) in _BATCHED_MATRIX_LIST, (TB, transB) in _BATCHED_MATRIX_LIST
30-
@eval begin
31-
function batched_mul!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where T
32-
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
33-
C
34-
end
40+
for (TA, transA) in _BATCHED_GEMM_LIST, (TB, transB) in _BATCHED_GEMM_LIST
41+
@eval function batched_mul!(C::StridedArray{T, 3}, A::$TA, B::$TB) where {T<:_GemmFloat}
42+
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
43+
C
44+
end
45+
end
3546

47+
# fallback
3648

49+
_BATCHED_LIST = [
50+
(:(AbstractArray{<:Any, 3}), :identity),
51+
(:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose),
52+
(:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint)
53+
]
54+
for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
55+
@eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB)
56+
axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch"))
57+
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
58+
A′, B′ = _unbatch(A), _unbatch(B)
59+
@inbounds for k in axes(C, 3)
60+
@views mul!(C[:,:,k], $fA(A′[:,:,k]), $fB(B′[:,:,k]))
61+
end
62+
C
3763
end
3864
end

test/batchedmul.jl

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,40 @@ function bmm_adjtest(a,b; adjA = false, adjB = false)
2222
cat(c...; dims = 3)
2323
end
2424

25-
@testset "Batched Matrix Multiplication" begin
25+
26+
@testset "batched_mul: Float64 * $TB" for TB in [Float64, Float32]
27+
2628
A = randn(7,5,3)
27-
B = randn(5,7,3)
29+
B = randn(TB, 5,7,3)
2830
C = randn(7,6,3)
2931

30-
@test batched_mul(A, B) == bmm_test(A, B)
31-
@test batched_mul(batched_transpose(A), batched_transpose(B)) == bmm_test(A, B; transA = true, transB = true)
32-
@test batched_mul(batched_transpose(A), C) == bmm_test(A, C; transA = true)
33-
@test batched_mul(A, batched_transpose(A)) == bmm_test(A, A; transB = true)
32+
@test batched_mul(A, B) ≈ bmm_test(A, B)
33+
@test batched_mul(batched_transpose(A), batched_transpose(B)) ≈ bmm_test(A, B; transA = true, transB = true)
34+
@test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA = true)
35+
@test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB = true)
3436

3537

3638
cA = randn(Complex{Float64}, 7,5,3)
37-
cB = randn(Complex{Float64}, 5,7,3)
39+
cB = randn(Complex{TB}, 5,7,3)
3840
cC = randn(Complex{Float64}, 7,6,3)
3941

40-
@test batched_mul(cA, cB) == bmm_adjtest(cA, cB)
41-
@test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) == bmm_adjtest(cA, cB; adjA = true, adjB = true)
42-
@test batched_mul(batched_adjoint(cA), cC) == bmm_adjtest(cA, cC; adjA = true)
43-
@test batched_mul(cA, batched_adjoint(cA)) == bmm_adjtest(cA, cA; adjB = true)
42+
@test batched_mul(cA, cB) ≈ bmm_adjtest(cA, cB)
43+
@test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) ≈ bmm_adjtest(cA, cB; adjA = true, adjB = true)
44+
@test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA = true)
45+
@test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB = true)
46+
47+
@test batched_transpose(batched_transpose(A)) === A
48+
@test batched_adjoint(batched_adjoint(cA)) === cA
49+
50+
TBi = TB==Float64 ? Int64 : Int32
51+
iA = rand(1:99, 7,5,3)
52+
iB = TB.(rand(1:99, 5,7,3))
53+
iC = zeros(Int, 7,6,3)
54+
@test batched_mul(iA, iB) == bmm_adjtest(iA, iB)
55+
@test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB)
56+
57+
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10))
58+
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2))
59+
@test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2))
4460

45-
@test batched_transpose(batched_transpose(A)) == A
46-
@test batched_adjoint(batched_adjoint(cA)) == cA
4761
end

0 commit comments

Comments
 (0)