@@ -15,7 +15,7 @@ function nnp_relu_output(batch_size, channels, input, output, negative_slope, th
@check ccall((:nnp_relu_output, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, input, output, negative_slope, threadpool)
end
-function nnp_relu_output(x::Array{Float32,N}, y::Array{Float32,N}; negative_slope::Float = 0.0, threadpool = shared_threadpool[]) where {N}
+function nnp_relu_output(x::Array{Float32,N}, y::Array{Float32,N}; negative_slope::AbstractFloat = 0.0, threadpool = shared_threadpool[]) where {N}
# Investigate why the channel and batch dims need to be specified like this
nnp_relu_output(prod(size(x)[N-1:N]), prod(size(x)[1:N-2]), x, y, negative_slope, threadpool)
y
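
The substantive change in this hunk is the keyword annotation: Julia Base has no `Float` type (the abstract supertype of Float16/Float32/Float64 is `AbstractFloat`), so the old annotation cannot work, while the new one accepts both the Float64 default `0.0` and a Float32 slope. A minimal sketch of the calls this enables, with a hypothetical input shape and assuming the wrapper above and libnnpack are loaded:

    x = rand(Float32, 28, 28, 3, 4)   # hypothetical WHCN input: width, height, channels, batch
    y = similar(x)
    nnp_relu_output(x, y)                             # default slope 0.0 (a Float64) matches ::AbstractFloat
    nnp_relu_output(x, y, negative_slope = 0.1f0)     # a Float32 slope matches as well
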
@@ -25,7 +25,7 @@ function nnp_relu_input_gradient(batch_size, channels, grad_output, input, grad_
@check ccall((:nnp_relu_input_gradient, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, grad_output, input, grad_input, negative_slope, threadpool)
end
-function nnp_relu_input_gradient(x::Array{Float32,N}, dy::Array{Float32,N}, dx::Array{Float32,N}; negative_slope::Float = 0.0, threadpool = shared_threadpool[]) where {N}
+function nnp_relu_input_gradient(x::Array{Float32,N}, dy::Array{Float32,N}, dx::Array{Float32,N}; negative_slope::AbstractFloat = 0.0, threadpool = shared_threadpool[]) where {N}
# Investigate why the channel and batch dims need to be specified like this
nnp_relu_input_gradient(Csize_t(prod(size(x)[N-1:N])), prod(size(x)[1:N-2]), dy, x, dx, negative_slope, threadpool)
dx
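
Both wrappers collapse the array dims into the two sizes the low-level NNPACK call expects: the product of the last two dims is passed as `batch_size` and the product of the first N-2 dims as `channels`. Since ReLU is elementwise, any split with `batch_size * channels == length(x)` presumably gives the same result, which may be why the dims can be lumped this way (the open comment above still asks for confirmation). A small sketch of that arithmetic, with a hypothetical 4-D shape:

    x = rand(Float32, 28, 28, 3, 4)
    N = ndims(x)
    prod(size(x)[N-1:N])    # 3 * 4   = 12, passed as the batch_size argument
    prod(size(x)[1:N-2])    # 28 * 28 = 784, passed as the channels argument
    prod(size(x)[N-1:N]) * prod(size(x)[1:N-2]) == length(x)   # true
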