Skip to content

Commit ff87776

Browse files
author
Michael Abbott
committed
add fallback batched_mul
1 parent 60ac742 commit ff87776

File tree

2 files changed

+57
-19
lines changed

2 files changed

+57
-19
lines changed

src/batched/batchedmul.jl

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,62 @@
22
# wrapper for batched_gemm!
33
export batched_mul, batched_transpose, batched_adjoint
44

5-
65
include("./batchedadjtrans.jl")
76

8-
function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T
9-
size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch size mismatch"))
10-
batched_mul!(similar(A, (size(A, 1), size(B, 2), size(A, 3))), A, B)
"""
    batched_mul(A, B) -> C

Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.

`A` and `B` are 3-dimensional arrays whose third axis is the batch axis. The output
`C` is allocated with `similar(A, T, ...)` where `T = promote_type(eltype(A), eltype(B))`,
so mixed element types (for example integer `A` with floating-point `B`) produce a
result wide enough to hold the products.

Throws `DimensionMismatch` if the batch axes of `A` and `B` differ.
"""
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
    axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
    # Allocate with the promoted element type: using eltype(A) alone would truncate
    # (or throw InexactError) whenever eltype(B) is wider than eltype(A).
    T = promote_type(T1, T2)
    C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3)))
    return batched_mul!(C, A, B)
end
1217

"""
    batched_mul!(C, A, B) -> C

In-place batched matrix multiplication, equivalent to
`mul!(view(C,:,:,k), view(A,:,:,k), view(B,:,:,k))` for all `k`
(views, not `C[:,:,k]` slices, since slicing would write into a copy).

# Arguments
- `C`: 3-dimensional output array; it is overwritten and returned.
- `A`, `B`: 3-dimensional inputs, optionally wrapped by `batched_transpose`
  or `batched_adjoint`.
"""
function batched_mul! end
1825

1926
# `_unbatch` removes a batched transpose/adjoint wrapper before the data is handed
# to a kernel: wrapped arrays yield their `parent` field, anything else is returned
# unchanged.
function _unbatch(A)
    return A
end

function _unbatch(A::BatchedAdjOrTrans)
    return A.parent
end
2128

22-
# bmm
23-
const _BATCHED_MATRIX_LIST = [
24-
(:(AbstractArray{T, 3}), 'N'),
25-
(:(BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T'),
26-
(:(BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C')
29+
# batched_gemm!
30+
31+
# Element types accepted by the `batched_gemm!`-based `batched_mul!` methods.
const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32}

# (argument type expression, transpose flag) pairs. The type expression is spliced
# into the generated method signature and the flag is forwarded to `batched_gemm!`.
# Declared `const`, as the list it replaces was, so the global cannot be rebound.
const _BATCHED_GEMM_LIST = [
    (:(StridedArray{T, 3}), 'N'),
    (:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'),
    (:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C'),
]
2838

29-
for (TA, transA) in _BATCHED_MATRIX_LIST, (TB, transB) in _BATCHED_MATRIX_LIST
30-
@eval begin
31-
function batched_mul!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where T
32-
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
33-
C
34-
end
39+
# Generate one `batched_mul!` method per (A-wrapper, B-wrapper) combination for
# strided arrays whose shared element type is a `_GemmFloat`. Each method unwraps
# its inputs and forwards the matching transpose flags to `batched_gemm!`.
for (TA, transA) in _BATCHED_GEMM_LIST
    for (TB, transB) in _BATCHED_GEMM_LIST
        @eval function batched_mul!(
            C::StridedArray{T, 3}, A::$TA, B::$TB
        ) where {T<:_GemmFloat}
            # Scalars passed to batched_gemm! alongside the unwrapped operands.
            α = one(T)
            β = zero(T)
            batched_gemm!($transA, $transB, α, _unbatch(A), _unbatch(B), β, C)
            return C
        end
    end
end
3545

46+
# fallback
3647

48+
# (argument type expression, per-slice function) pairs for the generic fallback.
# The function symbol is spliced into the generated method and applied to each
# 2-d slice of the unwrapped array.
# Declared `const` so the global cannot be rebound after the methods are generated.
const _BATCHED_LIST = [
    (:(AbstractArray{<:Any, 3}), :identity),
    (:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose),
    (:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint),
]
53+
# Generic fallback: one `batched_mul!` method per (A-wrapper, B-wrapper) combination.
# Each method checks that all three batch axes agree, then calls `mul!` on views of
# the k-th slices, re-applying `transpose`/`adjoint` per slice where the input was
# wrapped.
for (TA, fA) in _BATCHED_LIST
    for (TB, fB) in _BATCHED_LIST
        @eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB)
            if !(axes(A, 3) == axes(B, 3) == axes(C, 3))
                throw(DimensionMismatch("batch size mismatch"))
            end
            @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
            rawA = _unbatch(A)
            rawB = _unbatch(B)
            @inbounds for k in axes(C, 3)
                mul!(
                    view(C, :, :, k),
                    $fA(view(rawA, :, :, k)),
                    $fB(view(rawB, :, :, k)),
                )
            end
            return C
        end
    end
end

test/batchedmul.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ function bmm_adjtest(a,b; adjA = false, adjB = false)
2222
cat(c...; dims = 3)
2323
end
2424

25-
@testset "Batched Matrix Multiplication" begin
25+
26+
@testset "Batched Matrix Multiplication" for TB in [Float64, Float32]
27+
2628
A = randn(7,5,3)
27-
B = randn(5,7,3)
29+
B = randn(TB, 5,7,3)
2830
C = randn(7,6,3)
2931

3032
@test batched_mul(A, B) == bmm_test(A, B)
@@ -34,7 +36,7 @@ end
3436

3537

3638
cA = randn(Complex{Float64}, 7,5,3)
37-
cB = randn(Complex{Float64}, 5,7,3)
39+
cB = randn(Complex{TB}, 5,7,3)
3840
cC = randn(Complex{Float64}, 7,6,3)
3941

4042
@test batched_mul(cA, cB) == bmm_adjtest(cA, cB)
@@ -44,4 +46,15 @@ end
4446

4547
@test batched_transpose(batched_transpose(A)) == A
4648
@test batched_adjoint(batched_adjoint(cA)) == cA
49+
50+
TBi = TB==Float64 ? Int64 : Int32
51+
iA = rand(1:99, 7,5,3)
52+
iB = TB.(rand(1:99, 5,7,3))
53+
iC = zeros(Int, 7,6,3)
54+
@test batched_mul(iA, iB) == bmm_adjtest(iA, iB)
55+
56+
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10))
57+
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2))
58+
@test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2))
59+
4760
end

0 commit comments

Comments
 (0)