Skip to content

Commit 52e0310

Browse files
committed
Fix depthwise occasionally spitting out NaN
1 parent 303c900 commit 52e0310

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/impl/depthwiseconv_direct.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ See the docstring for `conv_direct!()` for more on the optional parameters.
2020
"""
2121
function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
2222
w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
23-
alpha::yT = yT(1), beta = false) where {yT, xT, wT}
23+
alpha::yT=yT(1), beta=false) where {yT, xT, wT}
2424
check_dims(size(x), size(w), size(y), cdims)
2525

2626
width, height, depth = input_size(cdims)
@@ -135,7 +135,7 @@ for each batch and channel independently.
135135
function ∇depthwiseconv_data_direct!(
136136
dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
137137
w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
138-
alpha::xT=xT(1), beta::xT=xT(0)) where {xT, yT, wT}
138+
alpha::xT=xT(1), beta=false) where {xT, yT, wT}
139139
# We do a separate convolution for each channel in x
140140
@inbounds for cidx in 1:channels_in(cdims)
141141
# For this batch and in-channel, we have a normal transposed convolution
@@ -168,7 +168,7 @@ Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`
168168
function ∇depthwiseconv_filter_direct!(
169169
dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
170170
dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims;
171-
alpha::wT=wT(1),beta::wT=wT(0)) where {xT, yT, wT}
171+
alpha::wT=wT(1),beta=false) where {xT, yT, wT}
172172
# We do a separate convolution for each channel in x
173173
@inbounds for cidx in 1:channels_in(cdims)
174174
# For this batch and in-channel, we have a normal transposed convolution

src/impl/depthwiseconv_im2col.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function depthwiseconv_im2col!(
1414
y::AbstractArray{T,5}, x::AbstractArray{T,5},
1515
w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
1616
col::AbstractArray{T,2} = similar(x, im2col_dims(cdims)),
17-
alpha=T(1), beta=T(0)) where T
17+
alpha::T=T(1), beta::T=T(0)) where T
1818
check_dims(size(x), size(w), size(y), cdims)
1919

2020
# This functions exactly the same as conv_im2col!(), except that we shard the
@@ -56,7 +56,7 @@ function ∇depthwiseconv_filter_im2col!(
5656
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
5757
dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
5858
col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)),
59-
alpha=T(1), beta=T(0)) where T
59+
alpha::T=T(1), beta::T=T(0)) where T
6060
check_dims(size(x), size(dw), size(dy), cdims)
6161

6262
M = prod(kernel_size(cdims))
@@ -96,7 +96,7 @@ function ∇depthwiseconv_data_im2col!(
9696
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
9797
w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
9898
col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)),
99-
alpha=T(1), beta=T(0)) where T
99+
alpha::T=T(1), beta::T=T(0)) where T
100100
check_dims(size(dx), size(w), size(dy), cdims)
101101

102102
M = prod(output_size(cdims))

0 commit comments

Comments
 (0)