Skip to content

Commit f8d53b8

Browse files
committed
Revert to multiplication without conjugate
1 parent b751af8 commit f8d53b8

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

src/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
100100
dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
101101
cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
102102
dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
103-
return conj($(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...))
103+
return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
104104
end
105105
end
106106
end
@@ -113,7 +113,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
113113
cdims::ConvDims; kwargs...) where {xT, yT, N}
114114
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
115115
channels_out(cdims))
116-
return conj($(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...))
116+
return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
117117
end
118118
end
119119

src/impl/conv_direct.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function conv_direct!(
107107
kproj(kh, kernel_h, fk),
108108
kproj(kd, kernel_d, fk),
109109
c_in, c_out]
110-
dotprod = muladd(x_val, conj(w_val), dotprod)
110+
dotprod = muladd(x_val, w_val, dotprod)
111111
end
112112
y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
113113
end
@@ -147,7 +147,7 @@ function conv_direct!(
147147
kproj(kh, kernel_h, fk),
148148
kproj(kd, kernel_d, fk),
149149
c_in, c_out]
150-
dotprod = muladd(x_val, conj(w_val), dotprod)
150+
dotprod = muladd(x_val, w_val, dotprod)
151151
end
152152
end
153153
end
@@ -169,7 +169,7 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`.
169169
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
170170
w::AbstractArray{wT,5}, cdims::DenseConvDims;
171171
alpha::xT=xT(1), beta=false) where {xT, yT, wT}
172-
w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :])
172+
w = conj(transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]))
173173
dy = predilate(dy, stride(cdims))
174174
ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims),
175175
dilation=dilation(cdims),
@@ -188,7 +188,7 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`.
188188
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
189189
dy::AbstractArray{yT,5}, cdims::DenseConvDims;
190190
alpha::wT=wT(1), beta=false) where {xT, yT, wT}
191-
x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :])
191+
x = conj(transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]))
192192
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
193193
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
194194
stride=dilation(cdims))

src/impl/conv_im2col.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function conv_im2col!(
5151
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
5252
GC.@preserve col_slice w y begin
5353
col_ptr = pointer(col_slice)
54-
w_ptr = pointer(copy(conj(w)))
54+
w_ptr = pointer(w)
5555
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
5656
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
5757
end

test/conv.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ end
373373
end
374374
end
375375
@testset "$(conv)" begin
376-
@test isapprox(ddims(conv(x, w, cdims)), [vec(w)' * vec(x)], rtol = 1.0e-7)
376+
@test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7)
377377
end
378378
end
379379
dy = NNlib.conv(x, w, cdims)
@@ -383,8 +383,8 @@ end
383383
(NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),
384384
)
385385
@testset "$(∇conv_filter)/$(∇conv_data)" begin
386-
@test isapprox(∇conv_filter(x, dy, cdims), x .* conj(dy), rtol = 1.0e-7)
387-
@test isapprox(∇conv_data(dy, w, cdims), conj(dy) .* w, rtol = 1.0e-7)
386+
@test isapprox(∇conv_filter(x, dy, cdims), conj(x) .* dy, rtol = 1.0e-7)
387+
@test isapprox(∇conv_data(dy, w, cdims), dy .* conj(w), rtol = 1.0e-7)
388388
end
389389
end
390390
end

0 commit comments

Comments
 (0)