diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl index 3513b54f7..8f71c67dd 100644 --- a/src/lib/nnlib.jl +++ b/src/lib/nnlib.jl @@ -1,5 +1,5 @@ using NNlib -import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, ∇depthwiseconv_data, maxpool, meanpool, σ, relu +import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, maxpool, meanpool, σ, relu @adjoint function Base.Broadcast.broadcasted(::typeof(relu), x::Numeric) relu.(x), Δ -> (nothing, ifelse.(x .> 0, Δ, zero.(x))) @@ -14,8 +14,8 @@ end @adjoint logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) -@adjoint NNlib.DenseConvDims(args...; kwargs...) = NNlib.DenseConvDims(args...; kwargs...), _ -> nothing -@adjoint NNlib.DepthwiseConvDims(args...; kwargs...) = NNlib.DepthwiseConvDims(args...; kwargs...), _ -> nothing +@adjoint NNlib.ConvDims(args...; kwargs...) = NNlib.ConvDims(args...; kwargs...), _ -> nothing +# @adjoint NNlib.DepthwiseConvDims(args...; kwargs...) = NNlib.DepthwiseConvDims(args...; kwargs...), _ -> nothing @adjoint NNlib.PoolDims(args...; kwargs...) = NNlib.PoolDims(args...; kwargs...), _ -> nothing @adjoint conv(x, w, cdims; kw...) = @@ -38,25 +38,25 @@ end ) end -@adjoint depthwiseconv(x, w, cdims; kw...) = - depthwiseconv(x, w, cdims; kw...), - Δ -> begin - return ( - NNlib.∇depthwiseconv_data(Δ, w, cdims; kw...), - NNlib.∇depthwiseconv_filter(x, Δ, cdims; kw...), - nothing, - ) - end - -@adjoint ∇depthwiseconv_data(x, w, cdims; kw...) = - ∇depthwiseconv_data(x, w, cdims; kw...), - Δ -> begin - return ( - NNlib.depthwiseconv(Δ, w, cdims; kw...), - NNlib.∇depthwiseconv_filter(Δ, x, cdims; kw...), - nothing, - ) - end +# @adjoint depthwiseconv(x, w, cdims; kw...) = +# depthwiseconv(x, w, cdims; kw...), +# Δ -> begin +# return ( +# NNlib.∇depthwiseconv_data(Δ, w, cdims; kw...), +# NNlib.∇depthwiseconv_filter(x, Δ, cdims; kw...), +# nothing, +# ) +# end +# +# @adjoint ∇depthwiseconv_data(x, w, cdims; kw...) = +# ∇depthwiseconv_data(x, w, cdims; kw...), +# Δ -> begin +# return ( +# NNlib.depthwiseconv(Δ, w, cdims; kw...), +# NNlib.∇depthwiseconv_filter(Δ, x, cdims; kw...), +# nothing, +# ) +# end @adjoint function maxpool(x, pdims; kw...) y = maxpool(x, pdims; kw...)