|
1 | 1 | # batch-wise matrix multiplication
|
2 | 2 | # wrapper for batched_gemm!
|
3 | 3 |
|
4 |
| -function batchedmul(a::AbstractArray{T, 3}, b::AbstractArray{T, 3}; |
5 |
| - transA::Bool = false, transB::Bool = false) where T |
6 |
| - (bs = size(a, 3)) == size(b, 3) || error("batch size mismatch") |
7 |
| - res = similar(a, size(a, transA ? 2 : 1), size(b, transB ? 1 : 2), bs) |
8 |
| - batched_mul!(res, a, b; transA=transA, transB=transB) |
9 |
| - return res |
10 |
| -end |
| 4 | +include("./batchedadjtrans.jl") |
11 | 5 |
|
12 |
| -function batched_mul!(C::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::AbstractArray{T, 3}; |
13 |
| - transA::Bool = false, transB::Bool = false) where T |
14 |
| - At = transA ? 'T' : 'N' |
15 |
| - Bt = transB ? 'T' : 'N' |
16 |
| - batched_gemm!(At, Bt, one(T), A, B, zero(T), C) |
17 |
| - C |
| 6 | +function batched_mul(A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T |
| 7 | + size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch size mismatch")) |
| 8 | + batched_mul!(similar(A, (size(A, 1), size(B, 2), size(A, 3))), A, B) |
18 | 9 | end
|
19 | 10 |
|
20 |
| -#gradient function for batchedmul |
21 |
| -function ∇batchedmul(Δ::AbstractArray{T, 3}, a::AbstractArray{T, 3}, b::AbstractArray{T, 3}; |
22 |
| - transA::Bool = false, transB::Bool = false) where T |
23 |
| - if transA |
24 |
| - if transB |
25 |
| - (batchedmul(b, Δ; transA=true, transB=true), batchedmul(Δ, a; transA=true, transB=true)) |
26 |
| - else |
27 |
| - (batchedmul(b, Δ; transB=true), batchedmul(a, Δ)) |
28 |
| - end |
29 |
| - else |
30 |
| - if transB |
31 |
| - (batchedmul(Δ, b), batchedmul(Δ, a; transA=true)) |
32 |
| - else |
33 |
| - (batchedmul(Δ, b; transB=true), batchedmul(a, Δ; transA=true)) |
| 11 | +""" |
| 12 | + batched_mul!(C, A, B) -> C |
| 13 | +batched `mul!`. |
| 14 | +""" |
| 15 | +function batched_mul! end |
| 16 | + |
| 17 | +_unbatch(A) = A |
| 18 | +_unbatch(A::BatchedAdjOrTrans) = A.parent |
| 19 | + |
| 20 | +# bmm |
| 21 | +const _BATCHED_MATRIX_LIST = [ |
| 22 | + (:(AbstractArray{T, 3}), 'N'), |
| 23 | + (:(BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T'), |
| 24 | + (:(BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C') |
| 25 | +] |
| 26 | + |
| 27 | +for (TA, transA) in _BATCHED_MATRIX_LIST, (TB, transB) in _BATCHED_MATRIX_LIST |
| 28 | + @eval begin |
| 29 | + function batched_mul!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where T |
| 30 | + batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C) |
| 31 | + C |
34 | 32 | end
|
| 33 | + |
| 34 | + |
35 | 35 | end
|
36 | 36 | end
|
| 37 | + |
| 38 | +function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::AbstractArray{T, 3}) where T |
| 39 | + (batched_mul(Δ, batched_transpose(B)), batched_mul(batched_transpose(A), Δ)) |
| 40 | +end |
| 41 | + |
| 42 | + |
| 43 | +function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T |
| 44 | + (batched_mul(Δ, batched_transpose(B)), batched_mul(A, Δ)) |
| 45 | +end |
| 46 | + |
| 47 | +function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::BatchedTranspose{T, <: AbstractArray{T, 3}}) where T |
| 48 | + (batched_mul(Δ, B), batched_mul(batched_transpose(A), Δ)) |
| 49 | +end |
| 50 | + |
| 51 | +function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::BatchedTranspose{T, <: AbstractArray{T, 3}}) where T |
| 52 | + (batched_mul(batched_transpose(Δ), batched_transpose(B)), batched_mul(batched_transpose(A), batched_transpose(Δ))) |
| 53 | +end |
0 commit comments