Skip to content

Commit b751af8

Browse files
committed
Convolution should conjugate one of its parameters
1 parent 4b94c88 commit b751af8

File tree

5 files changed

+36
-6
lines changed

5 files changed

+36
-6
lines changed

src/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
100100
dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
101101
cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
102102
dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
103-
return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
103+
return conj($(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...))
104104
end
105105
end
106106
end
@@ -113,7 +113,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
113113
cdims::ConvDims; kwargs...) where {xT, yT, N}
114114
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
115115
channels_out(cdims))
116-
return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
116+
return conj($(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...))
117117
end
118118
end
119119

src/gemm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ for (gemm, elt) in gemm_datatype_mappings
3535
beta::$(elt), C::Ptr{$elt})
3636
# Convert our compile-time transpose marker to a char for BLAS
3737
convtrans(V::Val{false}) = 'N'
38-
convtrans(V::Val{true}) = 'C'
38+
convtrans(V::Val{true}) = $elt <: Complex ? 'C' : 'T'
3939

4040
if transA == Val(false)
4141
lda = M

src/impl/conv_direct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function conv_direct!(
107107
kproj(kh, kernel_h, fk),
108108
kproj(kd, kernel_d, fk),
109109
c_in, c_out]
110-
dotprod = muladd(x_val, w_val, dotprod)
110+
dotprod = muladd(x_val, conj(w_val), dotprod)
111111
end
112112
y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
113113
end
@@ -147,7 +147,7 @@ function conv_direct!(
147147
kproj(kh, kernel_h, fk),
148148
kproj(kd, kernel_d, fk),
149149
c_in, c_out]
150-
dotprod = muladd(x_val, w_val, dotprod)
150+
dotprod = muladd(x_val, conj(w_val), dotprod)
151151
end
152152
end
153153
end

src/impl/conv_im2col.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function conv_im2col!(
5151
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
5252
GC.@preserve col_slice w y begin
5353
col_ptr = pointer(col_slice)
54-
w_ptr = pointer(w)
54+
w_ptr = pointer(copy(conj(w)))
5555
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
5656
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
5757
end

test/conv.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,36 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
359359
end
360360
end
361361

362+
@testset "Complex Dense Convolution" begin
363+
# For now only 1 dimensional 1x1 convolution
364+
x = reshape(complex.(Float64[1:4;], Float64[1:4;] .+ 1), 1, 4, 1)
365+
w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1)
366+
cdims = DenseConvDims(x, w)
367+
convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]
368+
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack)
369+
for conv in convs
370+
if NNlib.is_nnpack_available()
371+
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(cdims)
372+
continue
373+
end
374+
end
375+
@testset "$(conv)" begin
376+
@test isapprox(ddims(conv(x, w, cdims)), [vec(w)' * vec(x)], rtol = 1.0e-7)
377+
end
378+
end
379+
dy = NNlib.conv(x, w, cdims)
380+
for (∇conv_filter, ∇conv_data) in (
381+
(NNlib.∇conv_filter, NNlib.∇conv_data),
382+
(NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col),
383+
(NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),
384+
)
385+
@testset "$(∇conv_filter)/$(∇conv_data)" begin
386+
@test isapprox(∇conv_filter(x, dy, cdims), x .* conj(dy), rtol = 1.0e-7)
387+
@test isapprox(∇conv_data(dy, w, cdims), conj(dy) .* w, rtol = 1.0e-7)
388+
end
389+
end
390+
end
391+
362392
if get(ENV, "NNLIB_TEST_FUZZING", "false") == "true"
363393
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
364394
@testset "fuzzing" begin

0 commit comments

Comments
 (0)