Skip to content

Commit 2bd6d7e

Browse files
committed
Add adjoint for complex number
1 parent 6237de3 commit 2bd6d7e

File tree

3 files changed

+43
-11
lines changed

3 files changed

+43
-11
lines changed

src/NNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module NNlib
33
using Requires, Libdl
44

55
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
6-
softmax, logsoftmax, maxpool, meanpool, batched_mul, batched_transpose
6+
softmax, logsoftmax, maxpool, meanpool, batched_mul, batched_transpose, batched_adjoint
77

88
include("numeric.jl")
99
include("activation.jl")

src/linalg.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,27 @@ for (gemm, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32))
2626
transA, transB, M, N, K,
2727
alpha, A, lda, B, ldb, beta, C, ldc)
2828
end
29+
end
30+
end
2931

30-
## borrow BatchedRoutines.jl
31-
# batched gemm for 3d-array
32-
# C[:,:,i] := alpha*op(A[:,:,i])*op(B[:,:,i]) + beta*C[:,:,i], where:
33-
# i is the specific batch number,
34-
# op(X) is one of op(X) = X, or op(X) = XT, or op(X) = XH,
35-
# alpha and beta are scalars,
36-
# A, B and C are 3d Array:
37-
# op(A) is an m-by-k-by-b 3d Array,
38-
# op(B) is a k-by-n-by-b 3d Array,
39-
# C is an m-by-n-by-b 3d Array.
32+
33+
## borrow BatchedRoutines.jl
34+
# batched gemm for 3d-array
35+
# C[:,:,i] := alpha*op(A[:,:,i])*op(B[:,:,i]) + beta*C[:,:,i], where:
36+
# i is the specific batch number,
37+
# op(X) is one of op(X) = X, or op(X) = XT, or op(X) = XH,
38+
# alpha and beta are scalars,
39+
# A, B and C are 3d Array:
40+
# op(A) is an m-by-k-by-b 3d Array,
41+
# op(B) is a k-by-n-by-b 3d Array,
42+
# C is an m-by-n-by-b 3d Array.
43+
44+
for (gemm, elty) in
45+
((:dgemm_,:Float64),
46+
(:sgemm_,:Float32),
47+
(:zgemm_,:ComplexF64),
48+
(:cgemm_,:ComplexF32))
49+
@eval begin
4050
function batched_gemm!(transA::AbstractChar,
4151
transB::AbstractChar,
4252
alpha::($elty),

test/batchedmul.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ function bmm_test(a,b; transA = false, transB = false)
1010
cat(c...; dims = 3)
1111
end
1212

13+
function bmm_adjtest(a,b; adjA = false, adjB = false)
14+
bs = size(a,3)
15+
c = []
16+
for i = 1:bs
17+
ai = adjA ? adjoint(a[:,:,i]) : a[:,:,i]
18+
bi = adjB ? adjoint(b[:,:,i]) : b[:,:,i]
19+
push!(c, ai*bi)
20+
end
21+
22+
cat(c...; dims = 3)
23+
end
24+
1325
@testset "Batched Matrix Multiplication" begin
1426
A = randn(7,5,3)
1527
B = randn(5,7,3)
@@ -19,4 +31,14 @@ end
1931
@test batched_mul(batched_transpose(A), batched_transpose(B)) == bmm_test(A, B; transA = true, transB = true)
2032
@test batched_mul(batched_transpose(A), C) == bmm_test(A, C; transA = true)
2133
@test batched_mul(A, batched_transpose(A)) == bmm_test(A, A; transB = true)
34+
35+
36+
cA = randn(Complex{Float64}, 7,5,3)
37+
cB = randn(Complex{Float64}, 5,7,3)
38+
cC = randn(Complex{Float64}, 7,6,3)
39+
40+
@test batched_mul(cA, cB) == bmm_adjtest(cA, cB)
41+
@test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) == bmm_adjtest(cA, cB; adjA = true, adjB = true)
42+
@test batched_mul(batched_adjoint(cA), cC) == bmm_adjtest(cA, cC; adjA = true)
43+
@test batched_mul(cA, batched_adjoint(cA)) == bmm_adjtest(cA, cA; adjB = true)
2244
end

0 commit comments

Comments
 (0)