Skip to content

Commit aa86827

Browse files
authored
Merge pull request #389 from zsoerenm/complex-convolution-fix
Fix gradient of convolution for complex values
2 parents 51595b7 + 9b6d233 commit aa86827

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

src/gemm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ for (gemm, elt) in gemm_datatype_mappings
3535
beta::$(elt), C::Ptr{$elt})
3636
# Convert our compile-time transpose marker to a char for BLAS
3737
convtrans(V::Val{false}) = 'N'
38-
convtrans(V::Val{true}) = 'T'
38+
convtrans(V::Val{true}) = 'C'
3939

4040
if transA == Val(false)
4141
lda = M

src/impl/conv_direct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`.
169169
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
170170
w::AbstractArray{wT,5}, cdims::DenseConvDims;
171171
alpha::xT=xT(1), beta=false) where {xT, yT, wT}
172-
w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :])
172+
w = conj(transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]))
173173
dy = predilate(dy, stride(cdims))
174174
ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims),
175175
dilation=dilation(cdims),
@@ -188,7 +188,7 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`.
188188
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
189189
dy::AbstractArray{yT,5}, cdims::DenseConvDims;
190190
alpha::wT=wT(1), beta=false) where {xT, yT, wT}
191-
x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :])
191+
x = conj(transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]))
192192
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
193193
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
194194
stride=dilation(cdims))

test/conv.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,36 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
359359
end
360360
end
361361

362+
@testset "Complex Dense Convolution" begin
363+
# For now only 1 dimensional 1x1 convolution
364+
x = reshape(complex.(Float64[1:4;], Float64[1:4;] .+ 1), 1, 4, 1)
365+
w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1)
366+
cdims = DenseConvDims(x, w)
367+
convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]
368+
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack)
369+
for conv in convs
370+
if NNlib.is_nnpack_available()
371+
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(cdims)
372+
continue
373+
end
374+
end
375+
@testset "$(conv)" begin
376+
@test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7)
377+
end
378+
end
379+
dy = NNlib.conv(x, w, cdims)
380+
for (∇conv_filter, ∇conv_data) in (
381+
(NNlib.∇conv_filter, NNlib.∇conv_data),
382+
(NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col),
383+
(NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),
384+
)
385+
@testset "$(∇conv_filter)/$(∇conv_data)" begin
386+
@test isapprox(∇conv_filter(x, dy, cdims), conj(x) .* dy, rtol = 1.0e-7)
387+
@test isapprox(∇conv_data(dy, w, cdims), dy .* conj(w), rtol = 1.0e-7)
388+
end
389+
end
390+
end
391+
362392
if get(ENV, "NNLIB_TEST_FUZZING", "false") == "true"
363393
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
364394
@testset "fuzzing" begin

0 commit comments

Comments
 (0)